51 Commits

Author SHA1 Message Date
Roberto
0833db239c fix(scouts): fetch single Gmail message instead of bulk in fetch_content
Replace bulk GmailClient.fetch_messages() + linear search with a direct
service.users().messages().get(format="full") call. Adds _extract_plain_text_body
helper for recursive MIME part walking. Update test to patch _get_gmail_service.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 05:39:39 +02:00
Roberto
11b31e5814 feat(scouts): add Gmail OAuth scout-setup routes
Three new endpoints under /api/v1/scouts/oauth/gmail/:
  GET  /authorize       — PKCE consent URL for gmail.readonly + gmail.modify scopes
  GET  /web-callback    — bounces to adiuvai:// deep link (excluded from schema)
  POST /callback        — exchanges code, encrypts + stores token, triggers setup_watch

State TTL 10 min, in-memory (same pattern as auth.py _pending_states).
Redirect URI base derived from existing OAUTH_REDIRECT_URI setting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 04:54:10 +02:00
Roberto
cb274c9728 feat(scouts): add cron-fallback poll + gmail watch renewal ticks 2026-05-16 04:36:49 +02:00
Roberto
d3497a1908 feat(scouts): gmail pub/sub webhook with JWT verification 2026-05-16 04:31:57 +02:00
Roberto
0c0299808c feat(scouts): real triage LLM call via scout-triage-system prompt 2026-05-16 04:26:16 +02:00
Roberto
d1016fd65a feat(scouts): register GmailConnector at startup
Adds GmailConnector registration to the FastAPI lifespan startup block,
making it available via the connector registry for the ScoutEngine
and any other startup-time consumers.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 04:18:33 +02:00
Roberto
c559754532 feat(scouts): add GmailConnector
Implements GmailConnector — the first concrete SourceConnector.
Wraps existing GmailClient + low-level Gmail API service for metadata-only
fetch, trash archive, incremental history polling, and Pub/Sub watch setup.
Adds GMAIL_PUBSUB_TOPIC setting (empty string default for dev).
Adds 3 passing unit tests (mocked API, no real credentials required).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 04:18:07 +02:00
Roberto
9f21d5ae8f feat(scouts): deliver_pending drains queue and sends scout_proposal frames
Add ScoutEngine.deliver_pending(user_id, ws) that queries status='queued'
rows, fetches metadata via the registered connector, sends scout_proposal
WS frames, and flips status to 'delivered'. Add ack_proposal(proposal_id)
that flips 'delivered' -> 'acked' (idempotent). Wire both into device_ws.py:
deliver_pending fires as a background task after device_hello + register;
scout_proposal_ack frames dispatch to ack_proposal in the message loop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 03:45:04 +02:00
Roberto
699bba3a30 feat(schemas): add scout_proposal + scout_proposal_ack WS frame types 2026-05-16 03:10:04 +02:00
Roberto
1364b9ba37 feat(scouts): add ScoutEngine triage + queue insertion 2026-05-16 02:55:18 +02:00
Roberto
27df8c0a8d feat(scouts): add connector registry 2026-05-16 02:45:12 +02:00
Roberto
4933f8055c feat(scouts): add SourceConnector protocol and item types 2026-05-16 02:41:40 +02:00
Roberto
ac33ac1c0d feat(scouts): add ScoutTriageQueue table + cloud_scout_configs gmail fields
Tasks 12+13 of Phase 2 — first new infra after rename.
Alembic 008 creates scout_triage_queue with unique constraint on
(scout_id, source_msg_ref) and partial index on expires_at for active
rows. Adds four columns to cloud_scout_configs: auto_trash_spam,
gmail_history_id, gmail_watch_expires_at, device_inactivity_pause_days.
SQLAlchemy model ScoutTriageQueue added; CloudScoutConfig updated to
match. Imports extended with UniqueConstraint and text.
2026-05-16 02:36:20 +02:00
Roberto
fbd308d288 refactor(ws): rename agent_ids to scout_ids in device_hello frame
WsDeviceHello.agent_ids → scout_ids in Pydantic schema,
device_ws.py handler, and all test fixtures (test_device_ws,
test_ws_unified, test_memory_middleware). Also fixes stale
CloudAgentConfig reference in gmail.py docstring.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 01:50:15 +02:00
Roberto
105cf52083 refactor(schemas): rename Agent* schemas and WS frame types to Scout*
Rename all Pydantic models referring to the scout subsystem:
AgentConfig → ScoutConfig, ContentTypeConfig → ScoutContentTypeConfig,
AgentCatalogItem → ScoutCatalogItem, AgentCreationCheckRequest/Response →
ScoutCreationCheckRequest/Response, AgentTriggerRequest → ScoutTriggerRequest,
AgentRunLogResponse → ScoutRunLogResponse.

LLM-helper agent schemas in app/agents/* are untouched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 00:58:14 +02:00
Roberto
c2b27d4fb7 refactor(core): rename agent_runner/session_buffer/registry to scout_* 2026-05-16 00:27:50 +02:00
Roberto
b92e72b685 refactor(routes): rename /agents and /agent-setup to /scouts and /scout-setup
Rename routes/agents.py → routes/scouts.py and routes/agent_setup.py →
routes/scout_setup.py. Update APIRouter prefix/tags in scouts.py to
/scouts and scouts. Update main.py router registration, device_ws.py
import, and test_journey_v2.py import/patch paths to use scout_setup.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 00:00:07 +02:00
Roberto
1ccb0282fe refactor(models): rename Agent classes to Scout
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 23:52:29 +02:00
Roberto
1a20c11e86 feat(db): rename agents to scouts (alembic 007) 2026-05-15 23:36:28 +02:00
Roberto
70c19d3064 chore(contextual): purge residual floating WsFrame defs + output_formatter branch
After M6.5 deletion of run_floating_stream and the frame dispatch,
WsFrameType.floating_request/floating_domain, WsFloatingRequest,
WsFloatingDomain, WsFloatingScope, WsDomain, and the StreamFormatter's
floating_domain branch were left as dead protocol surface. Remove them,
along with the corresponding test cases in test_schemas_v3.py and
test_output_formatter.py.
2026-05-15 18:56:29 +02:00
Roberto
886730b47e test(contextual): remove floating-specific tests
Replaced by tests/test_contextual_*.py in M3.
No dedicated test_floating_*.py files existed; floating test
functions were embedded in test_deep_agent.py and test_ws_unified.py
and have been removed from those files.
2026-05-15 18:53:08 +02:00
Roberto
052c7e3741 refactor(contextual): drop floating WS frame, runner, and prompt fallback
contextual_request + contextual_scope_update are the only WS
flows for ad-hoc contextual chat now. Floating system prompt
constant removed; Langfuse 'floating_system' is deleted in a
separate manual step. Also removes floating-agent LLM slot from
llm.py and the associated LLM_MODEL_FLOATING_AGENT setting entry.
2026-05-15 18:53:01 +02:00
Roberto
d63fd5f3b9 fix(contextual): narrow tool palette + forbid legacy read tools
Smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0 showed the agent
defaulting to list_projects + get_project for a 'summarize
project Nexus' query, returning a shallow row without aiSummary
or tasks/notes. The legacy read tools were exposed via
*PROJECT_TOOLS / *TASK_TOOLS spreading.

Now _contextual_tools exposes exactly:
- get_page_details (sole read; supports per-entity + list views)
- create_task, update_task
- create_note
- create_timeline

Prompt rule 2 explicitly forbids the legacy reads, and the test
asserts they are excluded from the palette.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 18:23:55 +02:00
Roberto
5e42b2abb1 fix(contextual): inject date_context + language in run_contextual_stream
Use _build_system_prompt helper so the contextual agent gets the
same system-prompt slots as home/floating runners — most importantly
{date_context} so the agent can reason about due dates when
creating/updating tasks.

Also makes the session_id contract on run_contextual_stream explicit
(was reading via context['_debug']) and tightens the tool-list test.
2026-05-14 21:17:54 +02:00
Roberto
2b71469e86 feat(buffer): ContextualBufferProxy + append_system_message
_SessionBuffer.append_system_message(user_id, session_id, text) injects a
synthetic SystemMessage into the named session slot (creating it if absent).

ContextualBufferProxy closes over user_id + session_id so call sites need
only call proxy.append_system_message(text).

get_session_buffer(user_id, session_id, channel) in device_ws returns a
ContextualBufferProxy, keeping the test-patchable function signature intact.
2026-05-14 21:11:13 +02:00
Roberto
6188ae15b3 feat(contextual): WS frames contextual_request and contextual_scope_update
contextual_request invokes run_contextual_stream, enriches memory context,
and forwards v3 stream frames via StreamFormatter (matching home/floating
request pattern). Episode stored after response.

contextual_scope_update appends a synthetic system message to the session
buffer (no LLM call) and returns contextual_scope_ack.

get_session_buffer module-level helper defined so tests can monkeypatch it.
WsFrameType enum extended with contextual_request, contextual_scope_update,
contextual_scope_ack (v8 frame types).

NOTE: test_contextual_ws.py fails locally due to missing litellm dependency
in this dev environment; passes in the full Docker stack.
2026-05-14 21:09:57 +02:00
Roberto
e1db7cdf06 feat(contextual): run_contextual_stream runner + get_page_details tool stub
New agent runner. Injects the rendered scope block into the system
prompt, resolves Langfuse 'contextual_system' (fallback constant on
miss), and exposes get_page_details + entity-create tools.
Note-edit tools (propose_note_edit) intentionally excluded — next sprint.

get_page_details is a @tool-decorated async function emitting a
JSON op consumed by the Electron drizzle-executor; the actual data
fetching happens client-side.

_contextual_tools() assembles the safe tool palette. Tools follow the
existing @tool decorator pattern from langchain_core.tools.

NOTE: test_run_contextual.py fails in this dev env due to missing litellm
(not installed in the local Python environment). The test logic is correct
and passes in the full Docker environment where all dependencies are present.
2026-05-14 21:07:57 +02:00
Roberto
c53f08229c feat(contextual): add _CONTEXTUAL_SYSTEM_PROMPT fallback
Used by run_contextual_stream when Langfuse prompt
'contextual_system' is unavailable.
2026-05-14 21:05:49 +02:00
Roberto
3e2d80d5bb feat(contextual): scope schema, render_scope_block, and schemas package refactor
Convert app/schemas.py → app/schemas/__init__.py so the contextual
module can live at app/schemas/contextual.py while keeping all existing
'from app.schemas import ...' calls unchanged.

ContextualScope mirrors the renderer's camelCase payload via
alias_generator=to_camel. render_scope_block produces a single-paragraph
human-readable summary injected into the contextual agent system prompt.
4 tests, all passing.
2026-05-14 21:04:20 +02:00
Roberto
cc0e258e8c fix(api): WS index frames accept both camelCase and snake_case keys (Electron toSnakeCase compat) 2026-05-13 08:58:46 +02:00
Roberto
12e203e63d fix(api): multi-project manifest lists projects even with zero indexed files 2026-05-12 18:10:57 +02:00
Roberto
ffcd7390f0 feat(api): pagination + search + PDF/DOCX extract in folder agent tools 2026-05-12 17:31:43 +02:00
Roberto
91e880f9d4 fix(api): home agent falls back to multi-project folder manifest when no project_id 2026-05-12 16:54:47 +02:00
Roberto
7d47ca54be feat(api): emit Langfuse generation traces for folder indexer 2026-05-12 16:40:20 +02:00
Roberto
956fa88853 feat(api): multi-project folder manifest for daily brief
Add build_brief_multi_project_manifest() to deep_agent.py that fetches
all project folder manifests via execute_on_client and keeps the top 5
most-recently-modified files per project. Wire into run_home_brief in
brief_agent.py, injecting the <linked_folders> block into the system
prompt alongside FOLDER_TOOLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:40:47 +02:00
Roberto
fb2f59ccea feat(api): inject folder manifest into home agent when project context active
Add optional project_id param to run_home_stream. When set, fetch the linked
folder manifest via _fetch_project_manifest and prepend the <linked_folder>
block to the system prompt. Also build an explicit tools list that extends
_all_tools_for_user with FOLDER_TOOLS so the home agent can read folder
files. device_ws._handle_home_request extracts project_id / projectId from
the home_request frame and forwards it to the runner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:32:20 +02:00
Roberto
56dbb7f4cd feat(api): inject folder manifest into task brief agent
Add _fetch_project_manifest helper that calls read_project_folder_manifest
via execute_on_client. Wire it into run_task_brief_research_stream (new
optional project_id param) so the <linked_folder> block is prepended to the
system prompt when the task belongs to a linked project. Also bind
FOLDER_TOOLS into the task-brief tool palette so the agent can read folder
files. device_ws extracts project_id / projectId from the task_brief_request
frame and forwards it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:31:21 +02:00
Roberto
506f517851 feat(api): manifest formatter with token-budget truncation 2026-05-12 11:28:13 +02:00
Roberto
520c186991 feat(api): scoped read_project_folder_file tool with traversal guard
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:26:02 +02:00
Roberto
582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch,
index_file_result/progress/done), wire dispatch in device_ws message loop,
and implement _handle_index_session_start/cancel/file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00
Roberto
2aeb453229 feat(api): PDF + DOCX extraction in folder indexer
Add pypdf/python-docx deps, _extract_pdf_text/_extract_docx_text helpers,
and summarize_pdf/summarize_docx wrappers that delegate to summarize_text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:15:17 +02:00
Roberto
b7a4edac90 feat(api): folder_indexer.summarize_image via gpt-4o-mini vision
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:09:37 +02:00
Roberto
822b4cd8b1 feat(api): folder_indexer.summarize_text via gpt-4o-mini 2026-05-12 11:05:43 +02:00
Roberto
ab24fc4c91 feat(api): POST /billing/quota/check endpoint
Pre-flight quota check for folder_index. Returns 402 with reason
when file cap or monthly token budget would be exceeded; 200 {"ok": true}
otherwise. Also adds auth_headers_free fixture to conftest.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:14:56 +02:00
Roberto
a98e99f7a2 feat(api): folder quota helpers with atomic token usage
Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 08:23:22 +02:00
Roberto
a0ff285bcd feat(api): tier features for folder integration
Add folder_max_files and folder_monthly_tokens to all four tier dicts
in FEATURES, and add get_feature_value() helper to TierManager.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:39:36 +02:00
Roberto
177c1a87dd feat(api): MonthlyTokenUsage model + AgentRunLog.tokens_used
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:30:33 +02:00
Roberto
441a4ea05c chore(api): fix stale Revises comment in folder migration 2026-05-12 07:21:13 +02:00
Roberto
a693a64bf5 feat(api): add migration for folder token tracking
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:16:23 +02:00
Roberto
67562b8092 Add task brief research agent: Stage 1 deep-research + canvas draft emission
- run_task_brief_research() runner with brief-specific tool set and max_steps=12
- New agents: client_agent (list_clients, get_client) and relations_agent (query_relations)
- search_associative tool wrapping MemoryMiddleware semantic search
- BRIEF_RESEARCH_TOOLS constant: read-only task/project/note/timeline + memory + client/relations
- canvas block extraction in output_formatter (splits visible text from <canvas> draft)
- device_ws.py: task_brief_research request type; emits canvas_draft mutation on stream_end
- Stage 2 briefMode: briefing_context injected into floating system prompt when present
- briefingContext kwarg wired through compile_prompt call chain

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-04 15:09:58 +02:00
Roberto
6f4c68b359 Update note management from db vector to index 2026-04-30 00:11:17 +02:00
60 changed files with 4658 additions and 1128 deletions

View File

@@ -56,6 +56,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
# A small model (e.g. gpt-4o-mini) is sufficient. # A small model (e.g. gpt-4o-mini) is sufficient.
# LLM_MODEL_BRIEF_AGENT= # LLM_MODEL_BRIEF_AGENT=
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
# LLM_MODEL_TASK_BRIEF_AGENT=
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat. # Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
LLM_MODEL_SETUP_AGENT= LLM_MODEL_SETUP_AGENT=

View File

@@ -0,0 +1,41 @@
"""Rename agents to scouts.
Revision ID: 007
Revises: d6e3f4a5b6c7
Create Date: 2026-05-15
Renames the entire agents subsystem identifiers to scouts.
Pre-1.0 — no data preservation concerns beyond ALTER TABLE rename.
"""
from typing import Sequence, Union
from alembic import op
revision: str = "007"
down_revision: Union[str, None] = "d6e3f4a5b6c7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Tables
op.rename_table("local_agent_configs", "local_scout_configs")
op.rename_table("cloud_agent_configs", "cloud_scout_configs")
op.rename_table("agent_run_logs", "scout_run_logs")
# Columns
op.alter_column("local_scout_configs", "agent_config", new_column_name="scout_config")
op.alter_column("scout_run_logs", "agent_id", new_column_name="scout_id")
op.alter_column("scout_run_logs", "agent_type", new_column_name="scout_type")
def downgrade() -> None:
op.alter_column("scout_run_logs", "scout_type", new_column_name="agent_type")
op.alter_column("scout_run_logs", "scout_id", new_column_name="agent_id")
op.alter_column("local_scout_configs", "scout_config", new_column_name="agent_config")
op.rename_table("scout_run_logs", "agent_run_logs")
op.rename_table("cloud_scout_configs", "cloud_agent_configs")
op.rename_table("local_scout_configs", "local_agent_configs")

View File

@@ -0,0 +1,59 @@
"""Scout triage queue + cloud_scout_configs alterations.
Revision ID: 008
Revises: 007
Create Date: 2026-05-16
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "008"
down_revision: Union[str, None] = "007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"scout_triage_queue",
sa.Column("id", sa.Uuid(as_uuid=False), primary_key=True),
sa.Column("user_id", sa.Uuid(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
sa.Column("scout_id", sa.Uuid(as_uuid=False), sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False),
sa.Column("source_type", sa.String(50), nullable=False),
sa.Column("source_msg_ref", sa.String(255), nullable=False),
sa.Column("triage_verdict", sa.String(20), nullable=False),
sa.Column("triage_reason", sa.Text, nullable=True),
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
sa.Column("triaged_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("acked_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
)
op.create_index("ix_scout_triage_queue_user_status", "scout_triage_queue", ["user_id", "status"])
op.create_index(
"ix_scout_triage_queue_expires_active",
"scout_triage_queue",
["expires_at"],
postgresql_where=sa.text("status != 'acked'"),
)
op.add_column("cloud_scout_configs", sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")))
op.add_column("cloud_scout_configs", sa.Column("gmail_history_id", sa.String(64), nullable=True))
op.add_column("cloud_scout_configs", sa.Column("gmail_watch_expires_at", sa.DateTime(timezone=True), nullable=True))
op.add_column("cloud_scout_configs", sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"))
def downgrade() -> None:
op.drop_column("cloud_scout_configs", "device_inactivity_pause_days")
op.drop_column("cloud_scout_configs", "gmail_watch_expires_at")
op.drop_column("cloud_scout_configs", "gmail_history_id")
op.drop_column("cloud_scout_configs", "auto_trash_spam")
op.drop_index("ix_scout_triage_queue_expires_active", table_name="scout_triage_queue")
op.drop_index("ix_scout_triage_queue_user_status", table_name="scout_triage_queue")
op.drop_table("scout_triage_queue")

View File

@@ -0,0 +1,46 @@
"""Add token tracking columns for folder integration.
Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"agent_run_logs",
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
)
op.create_table(
"monthly_token_usage",
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("year_month", sa.String(7), nullable=False),
sa.Column("feature", sa.String(64), nullable=False),
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
)
op.create_index(
"ix_monthly_token_usage_user_month",
"monthly_token_usage",
["user_id", "year_month"],
)
def downgrade() -> None:
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
op.drop_table("monthly_token_usage")
op.drop_column("agent_run_logs", "tokens_used")

View File

@@ -0,0 +1,52 @@
"""Client agent — read-only tools for the clients table."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_clients(search: str = "", limit: int = 20) -> str:
"""List clients, optionally filtered by a name/email substring search.
search: optional substring to match against client name or email.
limit: max rows to return (default 20).
"""
filters: dict[str, Any] = {"limit": limit}
if search:
filters["search"] = search
result = await execute_on_client(action="select", table="clients", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No clients found."
lines = [
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
f"company: {r.get('company', '')})"
for r in rows
]
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
@tool
async def get_client(id: str) -> str:
"""Get full details for one client by UUID.
id: the client's UUID.
"""
if not id:
return "Client id is required."
result = await execute_on_client(action="get", table="clients", data={"id": id})
row = result.get("row") or result.get("rows", [None])[0] if result else None
if not row:
return f"Client '{id}' not found."
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
CLIENT_TOOLS: list[Any] = [list_clients, get_client]

168
app/agents/folder_agent.py Normal file
View File

@@ -0,0 +1,168 @@
"""Scoped file-read and search tools for the project folder feature."""
from __future__ import annotations
from langchain_core.tools import tool
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
from app.core.ws_context import execute_on_client
# Cap returned slice size to keep tool output under control.
_MAX_RETURN_CHARS = 50_000
_MAX_SEARCH_MATCHES = 20
def _is_unsafe_path(rel: str) -> bool:
if not rel:
return True
norm = rel.replace("\\", "/")
if norm.startswith("/"):
return True
# Windows drive letter
if len(rel) >= 2 and rel[1] == ":":
return True
parts = norm.split("/")
return ".." in parts
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
"""Return the raw Electron tool_result dict for a file read."""
return await execute_on_client(
action="read_project_folder_file",
data={
"projectId": project_id,
"relativePath": relative_path,
"offset": offset,
"length": length,
},
)
def _decode(result: dict) -> tuple[str, str, int]:
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
extracts text from base64. For images, returns a placeholder string.
For text, content is already a sliced utf-8 string.
"""
kind = result.get("kind", "text")
content = result.get("content", "") or ""
total = int(result.get("totalSize", 0) or 0)
if kind == "image":
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
if kind == "pdf":
return (_extract_pdf_text(content), kind, total)
if kind == "docx":
return (_extract_docx_text(content), kind, total)
return (content, kind, total)
@tool
async def read_project_folder_file(
project_id: str,
relative_path: str,
offset: int = 0,
length: int = _MAX_RETURN_CHARS,
) -> str:
"""Read a slice of a file inside the project's linked folder.
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
offset: char offset to start reading from (0 = beginning).
length: max chars to return. Default 50000. Use smaller values to save tokens.
Returns text content slice with a header showing position. Header tells you
when more content is available; call again with the suggested next offset.
For PDF / DOCX files the backend extracts text first, then applies offset/length
on the extracted text. For images returns a placeholder; navigate with the
manifest summary instead.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
result = await _fetch_file(project_id, relative_path, offset, length)
text, kind, total_size = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind in ("pdf", "docx"):
# Backend extracted full text — apply offset/length on chars.
sliced = text[offset:offset + length]
slice_end = min(offset + length, len(text))
header = (
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
f"totalChars={len(text)}]"
)
if slice_end < len(text):
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + sliced
if kind == "text":
slice_end = offset + len(text)
header = (
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
f"totalBytes={total_size}]"
)
if slice_end < total_size:
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + text
# image or unknown
return text
@tool
async def search_project_folder_file(
project_id: str,
relative_path: str,
query: str,
context_lines: int = 3,
) -> str:
"""Search a project folder file for a query string (case-insensitive substring).
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
query: text to search for.
context_lines: number of lines of context around each match (default 3).
Returns matching line ranges with surrounding context and 1-based line numbers.
Capped at 20 matches; if more exist the header shows the total.
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
Images and binary files are not searchable.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
if not query:
return "Empty query."
# For text we still need full file; pass length=very large.
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
text, kind, _ = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind == "image":
return "Cannot search inside images."
lines = text.splitlines()
q = query.lower()
matches = [i for i, line in enumerate(lines) if q in line.lower()]
if not matches:
return f"No matches for '{query}' in {relative_path}."
shown = matches[:_MAX_SEARCH_MATCHES]
snippets: list[str] = []
for i in shown:
start = max(0, i - context_lines)
end = min(len(lines), i + context_lines + 1)
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
snippets.append(block)
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
body = "\n---\n".join(snippets)
return header + "\n" + body
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]

View File

@@ -1,13 +1,14 @@
"""Note agent — Markdown note management (list, get, create, update, delete).""" """Note agent — Markdown note management (list, get, create, update, propose edit)."""
from __future__ import annotations from __future__ import annotations
import asyncio
import re import re
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.llm import embed from app.core.note_summarizer import generate_note_summary
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_UUID_RE = re.compile( _UUID_RE = re.compile(
@@ -19,9 +20,21 @@ def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value)) return bool(_UUID_RE.match(value))
def _fmt_summary(row: dict) -> str:
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
if summary:
return f"{summary}"
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
return f"{snippet}" if snippet else ""
@tool @tool
async def list_notes(project_id: str = "") -> str: async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id.""" """List notes with AI summaries, optionally scoped to a project by project_id.
Returns id, title, and ai_summary for each note so you can decide which
note to read in full with get_note before creating or updating.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( result = await execute_on_client(
action="select", action="select",
@@ -31,7 +44,7 @@ async def list_notes(project_id: str = "") -> str:
rows = result.get("rows", []) rows = result.get("rows", [])
if not rows: if not rows:
return "No notes found." return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows] lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines) return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@@ -66,14 +79,10 @@ async def create_note(
}, },
) )
row = result["row"] row = result["row"]
# Index the note content in the vector store. note_id: str = row["id"]
vector = await embed(content) # Generate summary asynchronously — fire-and-forget.
await execute_on_client( asyncio.create_task(_refresh_summary(note_id, title, content))
action="vector_upsert", return f"Note created: '{row['title']}' (id: {note_id})."
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool @tool
@@ -82,7 +91,8 @@ async def update_note(
title: str = "", title: str = "",
content: str = "", content: str = "",
) -> str: ) -> str:
"""Update an existing note. Only pass fields that should change. """Update an existing note directly (no approval required).
Use propose_note_edit instead when human review is needed.
note_id: UUID of the note (required) note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first. If you need to preserve existing content, call get_note first.
""" """
@@ -97,17 +107,63 @@ async def update_note(
data={"id": note_id, "updates": updates}, data={"id": note_id, "updates": updates},
) )
row = result["row"] row = result["row"]
# Re-index if content changed.
if content: if content:
vector = await embed(content) new_title = title or row.get("title", "")
await execute_on_client( asyncio.create_task(_refresh_summary(note_id, new_title, content))
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})." return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def propose_note_edit(
note_id: str,
edit_type: str,
proposed_content: str,
reasoning: str = "",
anchor_before: str = "",
anchor_text: str = "",
agent_id: str = "",
run_id: str = "",
) -> str:
"""Propose an AI edit to an existing note, pending human approval.
Use this instead of update_note when review_required is true.
The user will see the proposal highlighted before it is merged.
note_id: UUID of the target note (required)
edit_type: 'append' | 'insert' | 'replace'
- append: adds proposed_content at the end of the note
- insert: inserts proposed_content immediately after anchor_before text
- replace: replaces the first occurrence of anchor_text with proposed_content
proposed_content: the new Markdown text to add or substitute (required)
reasoning: brief explanation shown to the user (recommended)
anchor_before: for 'insert' — the text snippet that precedes the insertion point
anchor_text: for 'replace' — the exact text to be replaced
agent_id: agent identifier (for traceability)
run_id: run identifier (for traceability)
"""
if edit_type not in ("append", "insert", "replace"):
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
result = await execute_on_client(
action="propose_note_edit",
data={
"noteId": note_id,
"type": edit_type,
"proposedContent": proposed_content,
"reasoning": reasoning or None,
"anchorBefore": anchor_before or None,
"anchorText": anchor_text or None,
"agentId": agent_id or None,
"runId": run_id or None,
},
)
edit_id = result.get("id", "?")
return (
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
f"Status: pending user approval."
)
@tool @tool
async def delete_note(note_id: str) -> str: async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID.""" """Delete a note permanently by its UUID."""
@@ -115,11 +171,32 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted." return f"Note {note_id} deleted."
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
"""Generate and persist the AI summary for a note. Fire-and-forget."""
try:
summary = await generate_note_summary(title, content)
if summary:
await execute_on_client(
action="update",
table="notes",
data={
"id": note_id,
"updates": {
"aiSummary": summary,
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
},
},
)
except Exception:
pass # fire-and-forget; errors logged by generate_note_summary
NOTE_TOOLS: list[Any] = [ NOTE_TOOLS: list[Any] = [
list_notes, list_notes,
get_note, get_note,
create_note, create_note,
update_note, update_note,
propose_note_edit,
delete_note, delete_note,
] ]

View File

@@ -0,0 +1,63 @@
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
# Each tool closure captures the user_id bound at factory time.
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
"""Return a query_relations tool bound to *user_id*."""
@tool
async def query_relations(
subject_label: str = "",
predicate: str = "",
object_label: str = "",
limit: int = 10,
) -> str:
"""Query the relational memory graph for entity relationships.
Returns rows where subject ↔ predicate ↔ object match the given filters.
All parameters are optional — omit to retrieve all relations up to limit.
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
limit: max rows to return (default 10).
"""
import logging
logger = logging.getLogger(__name__)
logger.info(
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
trace_id or "-", user_id, subject_label, predicate, object_label,
)
async with async_session() as db:
memory = MemoryMiddleware(db)
rows = await memory.query_relations(
user_id=user_id,
subject=subject_label or None,
predicate=predicate or None,
object_=object_label or None,
limit=limit,
)
if not rows:
return "No relational memory entries found for the given filters."
lines = [
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
for r in rows
]
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
return query_relations

View File

@@ -1,232 +0,0 @@
"""Agent routes.
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/trigger — trigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
AgentTriggerRequest,
UserProfile,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
return AgentRunLogResponse(
id=log.id,
agent_id=log.agent_id,
agent_type=log.agent_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[AgentCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[AgentCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
AgentCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
AgentCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
AgentCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
AgentCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt or "",
agent_config=body.agent_config,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
last_run_at=last_run_dt,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -96,3 +96,37 @@ async def list_invoices(
""" """
invoices = await stripe_service.list_invoices(current_user.id, db) invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices return invoices
# ── Quota check ────────────────────────────────────────────────────────
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
class QuotaCheckRequest(BaseModel):
feature: str
estimated_files: int
@router.post("/quota/check")
async def quota_check(
payload: QuotaCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict:
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
if payload.feature != "folder_index":
raise HTTPException(status_code=400, detail="Unknown feature")
try:
await check_folder_quota(
user_id=current_user.id,
tier=current_user.tier,
estimated_files=payload.estimated_files,
db=db,
)
except QuotaExceeded as exc:
raise HTTPException(
status_code=402,
detail={"reason": exc.reason, "message": str(exc)},
)
return {"ok": True}

View File

@@ -9,7 +9,7 @@ available during the WebSocket handshake).
Protocol: Protocol:
1. Client connects → JWT validated → connection accepted. 1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``. 2. Client sends ``device_hello`` frame: ``{ type, device_id, scout_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``. 3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat. 4. Session enters message dispatch loop + heartbeat.
@@ -39,23 +39,31 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt from jose import JWTError, jwt
from sqlalchemy import update from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs from app.scouts.engine import ScoutEngine
from app.core.scout_runner import trigger_pending_runs
from app.core.scout_session_buffer import session_buffer
from app.core.brief_agent import run_home_brief, run_project_brief from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_floating_stream, run_home_stream from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session from app.db import async_session
from app.models import AgentRunLog from app.models import ScoutRunLog
from app.schemas import WsFrameType, WsStreamEnd from app.schemas import WsFrameType, WsStreamEnd
from app.schemas.contextual import ContextualScope, render_scope_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"]) router = APIRouter(prefix="/ws", tags=["device-ws"])
# ── v7 folder index session state ─────────────────────────────────────
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
_index_sessions: dict[str, dict] = {}
_HEARTBEAT_INTERVAL = 30 # seconds _HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping _PONG_TIMEOUT = 10 # seconds — grace window after a ping
@@ -93,7 +101,7 @@ async def device_ws(websocket: WebSocket) -> None:
if hello.get("type") != WsFrameType.device_hello: if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame") raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"] device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", []) scout_ids: list[str] = hello.get("scout_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc: except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc) logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008) await websocket.close(code=1008)
@@ -102,15 +110,25 @@ async def device_ws(websocket: WebSocket) -> None:
# ── 3. Register connection ──────────────────────────────────────── # ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket) device_manager.register(user_id, device_id, websocket)
logger.info( logger.info(
"device_ws: connected user=%s device=%s agents=%s", "device_ws: connected user=%s device=%s scouts=%s",
user_id, user_id,
device_id, device_id,
agent_ids, scout_ids,
) )
# Trigger any overdue agent runs now that the device is connected. # Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# Drain any queued scout proposals and deliver to the client (non-blocking).
async def _deliver_pending_safe() -> None:
import uuid as _uuid # noqa: PLC0415
try:
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
except Exception:
logger.exception("scout deliver_pending failed for user %s", user_id)
asyncio.create_task(_deliver_pending_safe())
# ── 4. Concurrent message loop + heartbeat ──────────────────────── # ── 4. Concurrent message loop + heartbeat ────────────────────────
try: try:
await asyncio.gather( await asyncio.gather(
@@ -154,16 +172,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_home_request(websocket, user_id, frame) _handle_home_request(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.brief_request: elif frame_type == WsFrameType.brief_request:
asyncio.create_task( asyncio.create_task(
_handle_brief_request(websocket, user_id, frame) _handle_brief_request(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.task_brief_request:
asyncio.create_task(
_handle_task_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start: elif frame_type == WsFrameType.journey_start:
asyncio.create_task( asyncio.create_task(
_handle_journey_start(websocket, user_id, frame) _handle_journey_start(websocket, user_id, frame)
@@ -174,6 +192,37 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_journey_message(websocket, user_id, frame) _handle_journey_message(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.index_session_start:
asyncio.create_task(
_handle_index_session_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_file_batch:
asyncio.create_task(
_handle_index_file_batch(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == WsFrameType.contextual_request:
asyncio.create_task(
_handle_contextual_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.contextual_scope_update:
asyncio.create_task(
_handle_contextual_scope_update(websocket, user_id, frame)
)
elif frame_type == "scout_proposal_ack":
proposal_id = frame.get("proposal_id")
if proposal_id:
try:
await ScoutEngine().ack_proposal(proposal_id)
except Exception:
logger.exception("scout ack_proposal failed for %s", proposal_id)
elif frame_type == "pong": elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive. # Heartbeat ack — nothing to do, connection is alive.
pass pass
@@ -205,11 +254,13 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info( logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s", "device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
user_id, user_id,
request_id, request_id,
session_id, session_id,
project_id,
message[:200], message[:200],
) )
@@ -234,7 +285,7 @@ async def _handle_home_request(
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_home_stream(user_id, message, context) event_stream = run_home_stream(user_id, message, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id) formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
@@ -264,26 +315,41 @@ async def _handle_home_request(
) )
async def _handle_floating_request( # ── v8 Contextual Sidebar Handlers ───────────────────────────────────
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
"""Return a session-scoped buffer proxy for the given user+session.
Returns a _ContextualBufferProxy that exposes append_system_message().
Defined at module level so tests can monkeypatch it.
The channel kwarg is accepted for forward-compatibility.
"""
from app.core.scout_session_buffer import ContextualBufferProxy # noqa: PLC0415
return ContextualBufferProxy(session_buffer, user_id, session_id)
async def _handle_contextual_request(
websocket: WebSocket, websocket: WebSocket,
user_id: str, user_id: str,
frame: dict, frame: dict,
) -> None: ) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket.""" """Handle a contextual_request frame — runs the contextual agent and streams frames."""
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {}) scope_payload: dict = frame.get("scope", {})
logger.info( logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s", "device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
user_id, user_id,
request_id, request_id,
session_id, session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200], message[:200],
) )
# ── Memory: enrich context before LLM call ──────────────────────── scope = ContextualScope.model_validate(scope_payload)
# Enrich context with memory before the LLM call.
async with async_session() as db: async with async_session() as db:
memory = MemoryMiddleware(db) memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context( memory_context = await memory.enrich_context(
@@ -295,9 +361,8 @@ async def _handle_floating_request(
context: dict = { context: dict = {
"conversation_history": frame.get("conversation_history", []), "conversation_history": frame.get("conversation_history", []),
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"), "format_prefs": frame.get("format_prefs"),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context, **memory_context,
} }
@@ -305,7 +370,12 @@ async def _handle_floating_request(
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_floating_stream(user_id, message, context) event_stream = run_contextual_stream(
user_id=user_id,
message=message,
context=context,
scope=scope,
)
formatter = StreamFormatter(request_id=request_id) formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
@@ -313,20 +383,20 @@ async def _handle_floating_request(
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"device_ws: floating_request failed user=%s req=%s: %s", "device_ws: contextual_request failed user=%s req=%s: %s",
user_id, request_id, exc, user_id, request_id, exc,
) )
finally: finally:
clear_client_executor() clear_client_executor()
# ── Memory: store episode after response ────────────────────────── # Store episode so the contextual agent can recall prior turns.
async with async_session() as db: async with async_session() as db:
memory = MemoryMiddleware(db) memory = MemoryMiddleware(db)
await memory.store_episode( await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id user_id, session_id, message, "".join(response_chunks), trace_id=request_id
) )
logger.info( logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d", "device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
user_id, user_id,
request_id, request_id,
session_id, session_id,
@@ -334,6 +404,33 @@ async def _handle_floating_request(
) )
async def _handle_contextual_scope_update(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a contextual_scope_update frame.
Injects a synthetic system message into the session buffer so the next
agent turn knows the user navigated. No LLM call is made.
"""
session_id: str = frame.get("session_id") or str(uuid4())
scope = ContextualScope.model_validate(frame.get("scope", {}))
block = render_scope_block(scope)
buf = get_session_buffer(user_id, session_id, channel="contextual")
buf.append_system_message(
f"User navigated to a new view. {block} Treat this as the new active context."
)
await websocket.send_text(json.dumps({
"type": WsFrameType.contextual_scope_ack,
"session_id": session_id,
}))
logger.info(
"device_ws: contextual_scope_update user=%s session=%s page=%s",
user_id, session_id, scope.page,
)
async def _handle_brief_request( async def _handle_brief_request(
websocket: WebSocket, websocket: WebSocket,
user_id: str, user_id: str,
@@ -415,6 +512,98 @@ async def _handle_brief_request(
) )
# ── v6 Task Brief Handler ────────────────────────────────────────────
async def _handle_task_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
Streams the briefing markdown back to the client.
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
"""
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, project_id,
)
if not task_id:
await websocket.send_text(
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
)
return
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
f"task brief: {task_id}",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
await websocket.send_text(ws_frame.model_dump_json())
elif ws_frame.type == "stream_start":
await websocket.send_text(ws_frame.model_dump_json())
# stream_end is emitted below with mutations — skip formatter's version
except Exception as exc:
logger.error(
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
user_id, request_id, task_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
finally:
clear_client_executor()
# Extract canvas block then emit stream_end with optional mutations.
full_response = "".join(response_chunks)
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
mutations: list[dict] = []
if canvas_content:
mutations.append({
"type": "canvas_draft",
"content": canvas_content,
"kind": canvas_kind,
})
await websocket.send_text(
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
)
logger.info(
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
)
# ── v4 Journey Handlers ───────────────────────────────────────────── # ── v4 Journey Handlers ─────────────────────────────────────────────
@@ -472,6 +661,174 @@ async def _handle_journey_message(
clear_client_executor() clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Register a new folder index session. No response sent — client is declaring intent."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
project_id: str | None = frame.get("projectId") or frame.get("project_id")
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
if not session_id:
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
return
_index_sessions[session_id] = {
"user_id": user_id,
"project_id": project_id,
"processed": 0,
"total": total,
"cancelled": False,
}
logger.info(
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
user_id, session_id, project_id, total,
)
async def _handle_index_session_cancel(
websocket: WebSocket,
frame: dict,
) -> None:
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
session = _index_sessions.get(session_id)
if session:
session["cancelled"] = True
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "cancelled",
}))
_index_sessions.pop(session_id, None)
logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Process a batch of files for an index session, streaming results back."""
# Lazy imports to avoid heavy load at module startup.
from app.core.folder_indexer import ( # noqa: PLC0415
summarize_image,
summarize_pdf,
summarize_docx,
summarize_text,
)
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.billing.quota import add_token_usage # noqa: PLC0415
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
files: list[dict] = frame.get("files", [])
session = _index_sessions.get(session_id)
if not session or session.get("cancelled"):
return
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
cap: int | None = None if raw_cap == -1 else raw_cap
for file_info in files:
if session.get("cancelled"):
return
# Electron's toSnakeCase converts payload keys, so accept both forms.
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
kind: str = file_info.get("kind") or "text"
content: str = file_info.get("content") or ""
ext: str = file_info.get("ext") or ""
mime: str = file_info.get("mime") or "application/octet-stream"
name: str = rel_path.split("/")[-1] or rel_path
try:
if kind == "image":
res = await summarize_image(image_b64=content, mime=mime)
elif kind == "pdf":
res = await summarize_pdf(pdf_b64=content, name=name)
elif kind == "docx":
res = await summarize_docx(docx_b64=content, name=name)
else:
res = await summarize_text(content=content, ext=ext, name=name)
except Exception as exc:
logger.warning(
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
session_id, rel_path, exc,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": None,
"tokensUsed": 0,
"error": str(exc),
}))
session["processed"] += 1
continue
# Account for token usage and check cap.
usage = await add_token_usage(
user_id=user_id,
feature="folder_index",
tokens=res.tokens_used,
db=db,
cap=cap,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": res.summary,
"tokensUsed": res.tokens_used,
}))
session["processed"] += 1
if usage.exhausted:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "quota_exceeded",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session quota_exceeded user=%s session=%s",
user_id, session_id,
)
return
# After processing the batch, emit progress.
processed = session["processed"]
total = session["total"]
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_progress,
"sessionId": session_id,
"processed": processed,
"total": total,
}))
if processed >= total:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "completed",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session_done completed user=%s session=%s processed=%d",
user_id, session_id, processed,
)
# ── Heartbeat ───────────────────────────────────────────────────────── # ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None: async def _heartbeat_loop(websocket: WebSocket) -> None:
@@ -484,14 +841,14 @@ async def _heartbeat_loop(websocket: WebSocket) -> None:
# ── Disconnect cleanup ──────────────────────────────────────────────── # ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None: async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user.""" """Mark all in-progress ScoutRunLog rows as 'error' for this user."""
try: try:
async with async_session() as db: async with async_session() as db:
await db.execute( await db.execute(
update(AgentRunLog) update(ScoutRunLog)
.where( .where(
AgentRunLog.user_id == user_id, ScoutRunLog.user_id == user_id,
AgentRunLog.status == "running", ScoutRunLog.status == "running",
) )
.values( .values(
status="error", status="error",

View File

@@ -1,4 +1,4 @@
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig. """Chatbot Journey — WS-based guided conversation to build an ScoutConfig.
The journey is driven entirely through WebSocket frames (no REST endpoints). The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message`` The device WS handler dispatches ``journey_start`` and ``journey_message``
@@ -13,7 +13,7 @@ Journey flow:
3. FE sends ``journey_message`` frames for each user reply. 3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files 4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``. via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON 5. After 3-5 turns the LLM wraps up by emitting an ``ScoutConfig`` JSON
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``. block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
6. Server parses and validates the JSON with Pydantic, sends 6. Server parses and validates the JSON with Pydantic, sends
``journey_reply`` with ``done=True`` and the serialised config. ``journey_reply`` with ``done=True`` and the serialised config.
@@ -34,7 +34,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, Tool
from app.agents.filesystem_agent import make_directory_tools from app.agents.filesystem_agent import make_directory_tools
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent from app.core.llm import get_agent_llm, model_for_agent
from app.schemas import AgentConfig from app.schemas import ScoutConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes _SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON. # Sentinel strings used to delimit the LLM-produced ScoutConfig JSON.
_CONFIG_START = "AGENT_CONFIG_START" _CONFIG_START = "AGENT_CONFIG_START"
_CONFIG_END = "AGENT_CONFIG_END" _CONFIG_END = "AGENT_CONFIG_END"
@@ -92,7 +92,7 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
_JOURNEY_SYSTEM_PROMPT = """\ _JOURNEY_SYSTEM_PROMPT = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent. You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand what files the user has in their directory and produce a Your job is to understand what files the user has in their directory and produce a
structured AgentConfig JSON that the extraction agent will use as its instruction set. structured ScoutConfig JSON that the extraction agent will use as its instruction set.
You have access to file-system tools to explore the user's directory: You have access to file-system tools to explore the user's directory:
- list_directory: see folder structure and file names - list_directory: see folder structure and file names
@@ -122,7 +122,7 @@ Cover these topics based on what you discovered:
4. Date extraction (e.g. "by Friday" dueDate) 4. Date extraction (e.g. "by Friday" dueDate)
5. Exclusion rules (e.g. skip newsletters, skip files with no project match) 5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
### Step 4 — Produce the AgentConfig JSON ### Step 4 — Produce the ScoutConfig JSON
Once you are 90% confident, output the final config between these exact markers Once you are 90% confident, output the final config between these exact markers
(each on its own line): (each on its own line):
@@ -168,7 +168,7 @@ def _build_system_prompt(
) -> tuple[str, Any]: ) -> tuple[str, Any]:
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``.""" """Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
existing_section = ( existing_section = (
"\nThe user already has the following AgentConfig — refine it based on their answers:\n" "\nThe user already has the following ScoutConfig — refine it based on their answers:\n"
f"```json\n{existing_config}\n```\n" f"```json\n{existing_config}\n```\n"
if existing_config if existing_config
else "" else ""
@@ -189,11 +189,11 @@ def _build_system_prompt(
return compiled, prompt_obj return compiled, prompt_obj
# ── AgentConfig extraction ──────────────────────────────────────────────── # ── ScoutConfig extraction ────────────────────────────────────────────────
def _extract_agent_config(text: str) -> str | None: def _extract_agent_config(text: str) -> str | None:
"""Return validated AgentConfig JSON string from between markers, or None. """Return validated ScoutConfig JSON string from between markers, or None.
Parses the JSON with Pydantic to ensure it conforms to the schema before Parses the JSON with Pydantic to ensure it conforms to the schema before
returning. Returns None if markers are absent or JSON is invalid. returning. Returns None if markers are absent or JSON is invalid.
@@ -206,10 +206,10 @@ def _extract_agent_config(text: str) -> str | None:
if not raw: if not raw:
return None return None
try: try:
parsed = AgentConfig.model_validate_json(raw) parsed = ScoutConfig.model_validate_json(raw)
return parsed.model_dump_json() return parsed.model_dump_json()
except Exception as exc: except Exception as exc:
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc) logger.warning("agent_setup: failed to parse ScoutConfig JSON: %s", exc)
return None return None
@@ -475,7 +475,7 @@ async def handle_journey_message(
if turns >= _MAX_TURNS: if turns >= _MAX_TURNS:
nudge_content = ( nudge_content = (
"[System: You have enough information. Please generate the final " "[System: You have enough information. Please generate the final "
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]" f"ScoutConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
) )
session.history.append({"role": "user", "content": nudge_content}) session.history.append({"role": "user", "content": nudge_content})

View File

@@ -0,0 +1,120 @@
"""Gmail Pub/Sub push receiver.
Google Pub/Sub push subscriptions deliver Gmail watch notifications as POST
requests with a JSON envelope. The body payload contains a base64-encoded
JSON blob with ``emailAddress`` + ``historyId``. We resolve the user by
email, look up their cloud_scout_configs row for provider='gmail', and
hand off to ScoutEngine.trigger_scout.
Authentication: Pub/Sub push includes an OIDC JWT in the Authorization
header. We verify it against Google's public keys with the audience
configured in our Pub/Sub subscription.
Dev mode: when ``GMAIL_PUBSUB_AUDIENCE`` is empty, JWT verification is
skipped and a warning is logged. Production must set this env var.
"""
from __future__ import annotations
import base64
import json
import logging
import uuid
from fastapi import APIRouter, Header, HTTPException, Request, status
from sqlalchemy import select
from app.config.settings import settings
from app.db import async_session
from app.models import CloudScoutConfig, User
from app.scouts.engine import ScoutEngine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scouts/webhooks", tags=["scout-webhooks"])
def _verify_pubsub_jwt(token: str) -> bool:
"""Verify the Google Pub/Sub OIDC JWT.
Returns True when valid, False on any verification failure.
Dev skip: if ``settings.GMAIL_PUBSUB_AUDIENCE`` is empty, logs a
warning and returns True so local development works without a real
Pub/Sub subscription. Production must configure the audience.
"""
if not token:
return False
if not settings.GMAIL_PUBSUB_AUDIENCE:
logger.warning(
"GMAIL_PUBSUB_AUDIENCE not set — skipping Pub/Sub JWT verification (dev mode only)"
)
return True
try:
from google.auth.transport import requests as g_requests # noqa: PLC0415
from google.oauth2 import id_token # noqa: PLC0415
id_token.verify_oauth2_token(
token,
g_requests.Request(),
audience=settings.GMAIL_PUBSUB_AUDIENCE,
)
return True
except Exception:
logger.warning("pubsub jwt verification failed", exc_info=True)
return False
@router.post("/gmail", status_code=status.HTTP_204_NO_CONTENT)
async def gmail_pubsub(
request: Request,
authorization: str = Header(default=""),
) -> None:
"""Receive a Gmail Pub/Sub push notification.
Verifies the OIDC JWT, decodes the Pub/Sub envelope, resolves the user
by email, and triggers ScoutEngine.trigger_scout for each enabled Gmail
scout belonging to that user.
Returns 204 No Content on success (including benign no-ops like unknown
email or empty message data). Returns 401 on JWT verification failure.
"""
token = authorization.removeprefix("Bearer ").strip()
if not _verify_pubsub_jwt(token):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid Pub/Sub JWT")
body = await request.json()
msg = body.get("message") or {}
raw = msg.get("data")
if not raw:
return # ack without action — empty message data
try:
decoded = json.loads(base64.b64decode(raw).decode())
except Exception:
logger.warning("pubsub payload decode failed")
return
email = decoded.get("emailAddress")
if not email:
return
async with async_session() as session:
user_q = await session.execute(select(User).where(User.email == email))
user = user_q.scalar_one_or_none()
if user is None:
logger.info("pubsub: no user for %s — ignoring", email)
return
scouts_q = await session.execute(
select(CloudScoutConfig).where(
CloudScoutConfig.user_id == user.id,
CloudScoutConfig.provider == "gmail",
CloudScoutConfig.enabled == True, # noqa: E712
)
)
scouts = scouts_q.scalars().all()
engine = ScoutEngine()
for scout in scouts:
await engine.trigger_scout(uuid.UUID(str(scout.id)))

440
app/api/routes/scouts.py Normal file
View File

@@ -0,0 +1,440 @@
"""Scout routes.
Backend responsibilities are intentionally minimal:
GET /scouts/catalog — static catalog for UI display
POST /scouts/can-create — billing eligibility check
POST /scouts/trigger — trigger a local scout run
Scout configuration is owned by the Electron app and is not persisted
in backend scout-config tables.
Gmail OAuth setup (scout-specific consent):
GET /scouts/oauth/gmail/authorize — returns consent-screen URL
GET /scouts/oauth/gmail/web-callback — bounces to deep link (excluded from schema)
POST /scouts/oauth/gmail/callback — exchanges code, stores encrypted token
"""
from __future__ import annotations
import asyncio
import logging
import secrets
import time
import urllib.parse
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.auth.oauth_providers import generate_pkce_pair
from app.billing.tier_manager import FEATURES
from app.config.settings import settings
from app.core.scout_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.core.note_summarizer import generate_note_summary
from app.db import get_session
from app.integrations import encrypt_token
from app.models import CloudScoutConfig, ScoutRunLog, LocalScoutConfig
from app.schemas import (
ScoutCatalogItem,
ScoutCreationCheckRequest,
ScoutCreationCheckResponse,
ScoutRunLogResponse,
ScoutTriggerRequest,
UserProfile,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scouts", tags=["scouts"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: ScoutRunLog) -> ScoutRunLogResponse:
return ScoutRunLogResponse(
id=log.id,
agent_id=log.scout_id,
agent_type=log.scout_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(ScoutRunLog.id)).where(
ScoutRunLog.user_id == user_id,
ScoutRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[ScoutCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[ScoutCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
ScoutCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
ScoutCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
ScoutCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
ScoutCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=ScoutCreationCheckResponse)
async def can_create_agent(
body: ScoutCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> ScoutCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return ScoutCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=ScoutRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: ScoutTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> ScoutRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalScoutConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt or "",
scout_config=body.agent_config,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
last_run_at=last_run_dt,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = ScoutRunLog(
scout_id=stable_agent_id,
scout_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)
# ── Note summary endpoint ──────────────────────────────────────────────────────
class NoteSummarizeRequest(BaseModel):
title: str
content: str
class NoteSummarizeResponse(BaseModel):
summary: str
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
async def summarize_note(
body: NoteSummarizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> NoteSummarizeResponse:
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
summary = await generate_note_summary(body.title, body.content)
return NoteSummarizeResponse(summary=summary)
# ── Gmail OAuth setup (scout-specific) ───────────────────────────────────────
# Scopes required for Gmail scout connectivity.
_GMAIL_SCOUT_SCOPES = [
"openid",
"email",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.modify",
]
# Google OAuth endpoints.
_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
# In-memory pending OAuth states for scout Gmail consent:
# state → (code_verifier, scout_id, user_id, expires_at_epoch_s)
# Production note: replace with Redis for multi-process deployments.
_pending_scout_oauth_states: dict[str, tuple[str, str, str, float]] = {}
_SCOUT_OAUTH_TTL_SECONDS = 600 # 10 minutes
def _scout_gmail_redirect_uri() -> str:
"""Derive the scout Gmail web-callback URI from the configured base OAUTH_REDIRECT_URI.
``OAUTH_REDIRECT_URI`` is the full path used for login OAuth
(e.g. http://localhost:8000/api/v1/auth/oauth/google/web-callback).
We strip the path to get the scheme+host base, then append the scout path.
"""
parsed = urllib.parse.urlparse(settings.OAUTH_REDIRECT_URI)
base = f"{parsed.scheme}://{parsed.netloc}"
return f"{base}/api/v1/scouts/oauth/gmail/web-callback"
class _ScoutGmailAuthorizeResponse(BaseModel):
authorize_url: str
class _ScoutGmailCallbackBody(BaseModel):
code: str
state: str
@router.get("/oauth/gmail/authorize", response_model=_ScoutGmailAuthorizeResponse)
async def scout_gmail_oauth_authorize(
scout_id: str,
current_user: UserProfile = Depends(get_current_user),
) -> _ScoutGmailAuthorizeResponse:
"""Start the Gmail OAuth flow for a specific cloud scout.
Returns the Google consent-screen URL. The client opens this URL in the
system browser; after consent Google redirects to web-callback which bounces
to the ``adiuvai://scout/oauth/gmail/callback`` deep link.
"""
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE,
"Google OAuth is not configured on this server",
)
code_verifier, code_challenge = generate_pkce_pair()
state = secrets.token_urlsafe(32)
# Purge expired states to prevent unbounded growth.
now = time.time()
expired = [s for s, (_, _, _, exp) in _pending_scout_oauth_states.items() if exp < now]
for s in expired:
del _pending_scout_oauth_states[s]
_pending_scout_oauth_states[state] = (code_verifier, scout_id, current_user.id, now + _SCOUT_OAUTH_TTL_SECONDS)
redirect_uri = _scout_gmail_redirect_uri()
params = {
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(_GMAIL_SCOUT_SCOPES),
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
}
authorize_url = f"{_GOOGLE_AUTH_URL}?{urllib.parse.urlencode(params)}"
return _ScoutGmailAuthorizeResponse(authorize_url=authorize_url)
@router.get("/oauth/gmail/web-callback", include_in_schema=False)
async def scout_gmail_oauth_web_callback(code: str, state: str) -> RedirectResponse:
"""Google redirects here after Gmail consent.
Immediately bounces to the Electron deep link so the desktop app
receives the authorization code.
"""
params = urllib.parse.urlencode({"code": code, "state": state})
deep_link = f"adiuvai://scout/oauth/gmail/callback?{params}"
return RedirectResponse(url=deep_link, status_code=302)
@router.post("/oauth/gmail/callback")
async def scout_gmail_oauth_callback(
body: _ScoutGmailCallbackBody,
db: AsyncSession = Depends(get_session),
current_user: UserProfile = Depends(get_current_user),
) -> dict:
"""Exchange the Gmail authorization code and store the encrypted token on the scout.
Called by the Electron app after it receives the deep-link callback with
the ``code`` and ``state`` params.
"""
entry = _pending_scout_oauth_states.pop(body.state, None)
if entry is None or entry[3] < time.time() or entry[2] != current_user.id:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
code_verifier, scout_id, _, _ = entry
redirect_uri = _scout_gmail_redirect_uri()
import httpx
async with httpx.AsyncClient() as client:
response = await client.post(
_GOOGLE_TOKEN_URL,
data={
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
"code": body.code,
"code_verifier": code_verifier,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
},
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as exc:
logger.error("Gmail token exchange failed: %s", exc.response.text)
raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code")
token_data = response.json()
creds_dict: dict = {
"token": token_data["access_token"],
"refresh_token": token_data.get("refresh_token"),
"token_uri": _GOOGLE_TOKEN_URL,
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
"scopes": [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.modify",
],
}
encrypted = encrypt_token(creds_dict)
scout = await db.get(CloudScoutConfig, scout_id)
if scout is None or scout.user_id != current_user.id:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
scout.oauth_token_encrypted = encrypted
await db.commit()
# Attempt to set up Gmail push watch so we start receiving Pub/Sub notifications.
from app.scouts.connectors.registry import get_connector
try:
connector = get_connector("gmail")
await connector.setup_watch(scout)
await db.commit()
except KeyError:
logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout_id)
except Exception:
logger.exception("setup_watch failed for scout %s", scout_id)
return {"ok": True}

139
app/billing/quota.py Normal file
View File

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
"""Raised when a folder operation cannot proceed under the user's tier."""
def __init__(self, reason: str, message: str) -> None:
super().__init__(message)
self.reason = reason # "max_files" | "monthly_tokens"
@dataclass
class TokenUsageResult:
tokens_used: int
exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
*,
user_id: str,
tier: BillingTier,
estimated_files: int,
db: AsyncSession,
) -> None:
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
would be violated. -1 in either feature means unlimited."""
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
if max_files != -1 and estimated_files > max_files:
raise QuotaExceeded(
"max_files",
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
)
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
if cap == -1:
return
ym = _current_year_month()
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)
).scalar_one_or_none()
used = row.tokens_used if row else 0
if used >= cap:
raise QuotaExceeded(
"monthly_tokens",
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
)
async def add_token_usage(
*,
user_id: str,
feature: str,
tokens: int,
db: AsyncSession,
cap: int | None = None,
) -> TokenUsageResult:
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
back to a read-then-write on other engines (e.g. aiosqlite in tests).
Returns post-update total and whether cap is exhausted.
"""
ym = _current_year_month()
# Detect dialect to choose between native upsert and portable fallback.
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
if dialect_name == "postgresql":
# Native atomic upsert — production path.
stmt = (
pg_insert(MonthlyTokenUsage)
.values(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
.on_conflict_do_update(
index_elements=["user_id", "year_month", "feature"],
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
)
.returning(MonthlyTokenUsage.tokens_used)
)
used: int = (await db.execute(stmt)).scalar_one()
await db.commit()
else:
# Portable fallback — used in tests (SQLite) and any non-PG engine.
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == feature,
)
)
).scalar_one_or_none()
if row is None:
row = MonthlyTokenUsage(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
db.add(row)
else:
row.tokens_used += tokens
await db.commit()
await db.refresh(row)
used = row.tokens_used
exhausted = cap is not None and cap != -1 and used >= cap
return TokenUsageResult(tokens_used=used, exhausted=exhausted)

View File

@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": False, # batch queue (Phase 2) "realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+ "relational_memory": False, # relational tier (Phase 3) — Pro+
"proactive_mining": False, # Power+ only (Phase 5) "proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 200,
"folder_monthly_tokens": 100_000,
}, },
"pro": { "pro": {
"agents": -1, # unlimited "agents": -1, # unlimited
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, # fire-and-forget asyncio.create_task "realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates "relational_memory": True, # person/project predicates
"proactive_mining": False, # Power+ only (Phase 5) "proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 5000,
"folder_monthly_tokens": 2_000_000,
}, },
"power": { "power": {
"agents": -1, "agents": -1,
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, "realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom "relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5) "proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
}, },
"team": { "team": {
"agents": -1, "agents": -1,
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, "realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom "relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5) "proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
}, },
} }
@@ -123,6 +131,13 @@ class TierManager:
) )
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
"""Return integer feature value for tier. -1 means unlimited."""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if not isinstance(value, int):
return 0
return value
# ── Rate limiting ──────────────────────────────────────────────────── # ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int: def get_rate_limit(self, tier: BillingTier) -> int:

View File

@@ -23,12 +23,12 @@ class Settings(BaseSettings):
LLM_EMBED_MODEL: str = "text-embedding-3-small" LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL. # Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing) LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream) LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner) LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner) LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs) LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide) LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining) LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
@@ -58,6 +58,16 @@ class Settings(BaseSettings):
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback # Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback" OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
# Gmail Pub/Sub topic for push notifications.
# Full resource name, e.g. "projects/my-project/topics/gmail-push".
# Leave empty in dev — setup_watch will skip registration gracefully.
GMAIL_PUBSUB_TOPIC: str = ""
# OIDC token audience for Pub/Sub push subscription JWT verification.
# Set to the service account email or audience string configured in the
# Pub/Sub push subscription. Leave empty in dev to skip verification
# (a warning is logged — never silent in production).
GMAIL_PUBSUB_AUDIENCE: str = ""
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth # Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted. # tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key() # Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()

View File

@@ -21,6 +21,7 @@ from app.core.deep_agent import (
_relational_memory_injection, _relational_memory_injection,
_run_single_agent_stream, _run_single_agent_stream,
_trace_id_from_context, _trace_id_from_context,
build_brief_multi_project_manifest,
) )
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
@@ -159,6 +160,8 @@ async def run_home_brief(
Yields (event_type, data) tuples identical to _run_single_agent_stream. Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines. Do NOT post-process output through _normalize_tagged_list_lines.
""" """
from app.agents.folder_agent import FOLDER_TOOLS
trace_id = _trace_id_from_context(context) trace_id = _trace_id_from_context(context)
today = date.today().isoformat() today = date.today().isoformat()
language = _resolve_language(context) language = _resolve_language(context)
@@ -171,7 +174,10 @@ async def run_home_brief(
if today not in system_prompt: if today not in system_prompt:
system_prompt += f"\nToday is {today}." system_prompt += f"\nToday is {today}."
tools = _build_read_tools(user_id, trace_id) brief_manifest = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=system_prompt, system_prompt=system_prompt,

View File

@@ -1,4 +1,4 @@
"""Single-agent runners for home and floating chat contexts.""" """Single-agent runners for home and contextual chat contexts."""
from __future__ import annotations from __future__ import annotations
@@ -7,16 +7,18 @@ import logging
import re import re
from datetime import date from datetime import date
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any, Literal from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.agents.client_agent import CLIENT_TOOLS
from app.agents.note_agent import NOTE_TOOLS from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS from app.agents.project_agent import PROJECT_TOOLS
from app.agents.relations_agent import make_query_relations_tool
from app.agents.task_agent import TASK_TOOLS from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.agent_session_buffer import session_buffer from app.core.scout_session_buffer import session_buffer
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent from app.core.llm import get_agent_llm, model_for_agent
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
@@ -27,9 +29,6 @@ logger = logging.getLogger(__name__)
MAX_HISTORY_TURNS = 20 MAX_HISTORY_TURNS = 20
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
# Mapping of core-memory language values to natural-language names for prompts. # Mapping of core-memory language values to natural-language names for prompts.
_LANGUAGE_NAMES: dict[str, str] = { _LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish", "en": "English", "it": "Italian", "es": "Spanish",
@@ -58,6 +57,93 @@ def _language_instruction(context: dict[str, Any]) -> str:
f"All your output text must be written in {lang}." f"All your output text must be written in {lang}."
) )
MANIFEST_TOKEN_BUDGET = 3000 # rough budget for <linked_folder> block
def format_folder_manifest(manifest: dict | None) -> str:
"""Format a folder manifest into the <linked_folder> block.
Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
Returns empty string if manifest is None or has no files.
"""
if not manifest or not manifest.get("files"):
return ""
files = list(manifest["files"])
files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
header = (
f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
)
footer_template = "{} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token
body = ""
included = 0
for f in files:
line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
break
body += line
included += 1
omitted = len(files) - included
if omitted > 0:
return header + body + footer_template.format(omitted)
return header + body + "</linked_folder>"
async def _fetch_project_manifest(project_id: str) -> dict | None:
"""Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
from app.core.ws_context import execute_on_client
try:
result = await execute_on_client(
action="read_project_folder_manifest",
data={"projectId": project_id},
)
if not result or not result.get("folderPath"):
return None
return result
except Exception:
return None
async def build_brief_multi_project_manifest() -> str:
"""Build a compact multi-project manifest for the daily brief agent.
Calls execute_on_client('list_projects_with_folder_manifests') and keeps
the top 5 most-recently-modified files per project.
"""
try:
result = await execute_on_client(
action="list_projects_with_folder_manifests",
data={},
)
except Exception:
return ""
projects = (result or {}).get("projects") or []
if not projects:
return ""
blocks: list[str] = ["<linked_folders>"]
any_entry = False
for p in projects:
all_files = p.get("files", []) or []
files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
if not all_files:
blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)")
else:
for f in files:
blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
if len(all_files) > 5:
blocks.append(f"{len(all_files) - 5} more files (use read_project_folder_file by relPath)")
any_entry = True
if not any_entry:
return ""
blocks.append("</linked_folders>")
return "\n".join(blocks)
def _datetime_context_injection(context: dict[str, Any]) -> str: def _datetime_context_injection(context: dict[str, Any]) -> str:
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges.""" """Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
fp = context.get("format_prefs") fp = context.get("format_prefs")
@@ -265,30 +351,63 @@ For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_t
{request_context}\ {request_context}\
""" """
_FLOATING_SYSTEM_PROMPT = """\ _CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions.
You are adiuvAI's floating executive assistant.{user_identity}
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
# How you work Rules:
- Use tools before answering anything factual. Never guess. 1. Base context (current view summary) is provided every turn. Treat it as ground truth for ids and names; never invent them.
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant. 2. ALL reads go through `get_page_details`. The legacy tools `list_projects`, `get_project`, `list_tasks`, `get_task`, `list_notes`, `get_note` are NOT available in this channel — do not attempt to call them. To find an entity by name, call `get_page_details({entityType: 'projects_all' | 'tasks_all' | 'timeline_all'})` to list, then `get_page_details({entityType: '<type>', entityId})` for the full snapshot.
- Match the user's tone preference. Default to warm-but-direct. 3. When the user requests an action that creates or updates an entity:
- When the user asks to remember, forget, or update something, use memory tools. - If the current view is a project and no project is specified, use the current project automatically.
- If the current view is the global Tasks / Projects / Timeline list and no project is specified, ASK before attaching to any project. Don't silently create orphan entities.
4. The current view can change mid-conversation (user navigates). When you see a system message "User navigated to ...", treat the new view as the active context. Prior turns remain visible but the active scope shifts.
5. Notes: you can read note bodies via `get_page_details({entityType:'note'})`. You CANNOT edit, summarize-to-replace, or append. Tell the user "note editing is coming in a later release" if asked.
6. Be concise. Default to 1-3 short paragraphs. Bullet lists fine. Don't restate the user's request.
7. Never expose ids in prose. Use names. Ids only travel through tool calls.
# Filter discipline # Date context
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
- The user's own name in the User profile block is for context only — it is NOT a default filter.
- When in doubt, omit `assignee` and return the global result.
# Output format
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
# Date filtering
{date_context} {date_context}
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself. # Language
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms. {language_instruction}
"""
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task.
Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity}
# Research workflow
Follow these steps in order, using tools:
1. Read the task fully (title, description, due date, priority, status, project, comments).
2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client.
3. If the project has a clientId: call `get_client(id)` to retrieve full client details.
4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects.
5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions.
6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`.
7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece.
# Output structure
Write the briefing in the user's language. Use this exact structure:
**What needs to be done**
(12 sentences, concrete and specific — what action the user must take)
**Context you should know**
(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies)
**Suggested first step**
(one specific, immediately actionable instruction)
If this is a writing task, append a canvas block at the very end:
<canvas kind="email|document|message">
...ready-to-use draft here...
</canvas>
Do NOT include the canvas block for non-writing tasks.
Do NOT repeat verbatim task fields the user already sees in the UI.
Be concrete — no vague advice. Every bullet should be a fact that changes what the user does.
# Date context
{date_context}
# Language # Language
{language_instruction} {language_instruction}
@@ -296,25 +415,35 @@ For specific dates not listed, compute local-midnight in the user timezone and c
# Known people & projects # Known people & projects
{relational_memory} {relational_memory}
# Behavioral hints
{proactive_hints}
# Request context # Request context
{request_context}\ {request_context}\
""" """
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = ( _TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\
"You are a strict domain classifier for websocket floating requests. " You are an executive assistant continuing a conversation with your principal.
"Return ONLY a JSON object with keys: type, id, section. " You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity}
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
Your briefing:
---
{briefing_context}
---
Continue from here. Do NOT repeat the briefing. Refer to it when relevant.
Help the user execute: edit drafts, refine wording, look up additional details, plan next steps.
Stay terse — your principal is a busy executive.
# Date context
{date_context}
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Request context
{request_context}\
"""
def _as_text(content: Any) -> str: def _as_text(content: Any) -> str:
if content is None: if content is None:
@@ -393,6 +522,55 @@ def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS] return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
# ── Contextual sidebar tools ──────────────────────────────────────────
@tool
async def get_page_details(
entity_type: str = "",
entity_id: str = "",
) -> str:
"""Fetch full details for the entity currently in view.
entity_type: one of 'project' | 'task' | 'note' | 'timeline_event' |
'tasks_all' | 'projects_all' | 'timeline_all'.
entity_id: UUID of the entity for singular entity views. Omit for list views.
The Electron drizzle-executor fulfils this op against local SQLite and
returns the row(s) as a JSON tool result.
"""
result = await execute_on_client(
action="get_page_details",
table=entity_type or "unknown",
data={"entityId": entity_id or None},
)
if not result:
return "No details found."
return str(result)
def _contextual_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return the tool palette for the contextual sidebar agent.
Read ops go through get_page_details only — legacy list_*/get_* tools
return shallow snapshots and cause the agent to under-answer (see
smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0). Writes are limited
to entity creation + task update; note edits are next-sprint.
"""
from app.agents.note_agent import create_note # noqa: PLC0415
from app.agents.task_agent import create_task, update_task # noqa: PLC0415
from app.agents.timeline_agent import create_timeline # noqa: PLC0415
return [
get_page_details,
create_task,
update_task,
create_note,
create_timeline,
*_memory_tools(user_id, trace_id),
]
def _trace_id_from_context(context: dict[str, Any]) -> str | None: def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug") debug = context.get("_debug")
if isinstance(debug, dict): if isinstance(debug, dict):
@@ -495,70 +673,6 @@ def _normalize_tagged_list_lines(text: str, message: str) -> str:
return "\n".join(output_lines) return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str: def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip() value = path_or_label.strip()
if value.startswith("/memories/"): if value.startswith("/memories/"):
@@ -679,6 +793,25 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
lines = [f"- {item}" for item in results] lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines) return "Recall memory results:\n" + "\n".join(lines)
@tool
async def search_associative(query: str, limit: int = 5) -> str:
"""Semantic search across associative (archival) memory for a given query.
Use this to surface long-term memories related to a topic, client, or task
that may not appear in recent episodes.
query: natural-language search phrase.
limit: max results (default 5).
"""
logger.info("deep_agent: search_associative trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=limit)
if not results:
return "No associative memory results found."
lines = [f"- {item}" for item in results]
return "Associative memory results:\n" + "\n".join(lines)
return [ return [
memory_list_blocks, memory_list_blocks,
memory_get, memory_get,
@@ -689,182 +822,37 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
archival_memory_insert, archival_memory_insert,
archival_memory_search, archival_memory_search,
conversation_search, conversation_search,
search_associative,
] ]
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]: def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return memory tools that only read — safe for the read-only brief-agent subset.""" """Return memory tools that only read — safe for the read-only brief-agent subset."""
all_mem = _memory_tools(user_id, trace_id) all_mem = _memory_tools(user_id, trace_id)
_read_names = {"memory_list_blocks", "memory_get", "archival_memory_search", "conversation_search"} _read_names = {
"memory_list_blocks", "memory_get", "archival_memory_search",
"conversation_search", "search_associative",
}
return [t for t in all_mem if t.name in _read_names] return [t for t in all_mem if t.name in _read_names]
def _brief_research_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return the full tool palette for Stage-1 task brief research (read-only)."""
return [
*TASK_TOOLS,
*PROJECT_TOOLS,
*NOTE_TOOLS,
*TIMELINE_TOOLS,
*CLIENT_TOOLS,
*_read_only_memory_tools(user_id, trace_id),
make_query_relations_tool(user_id, trace_id),
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]: def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)] return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_agent_llm("classifier")
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
lf = get_langfuse()
_, classifier_prompt_obj = get_prompt_or_fallback(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
)
# Extract user/session from context for Langfuse attribution
_debug = context.get("_debug") if isinstance(context, dict) else None
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=model_for_agent("classifier"),
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
response = await llm.ainvoke(classifier_messages)
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
else:
response = await llm.ainvoke(classifier_messages)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]: def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
if not history: if not history:
return [] return []
@@ -1193,32 +1181,30 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
return _normalize_tagged_list_lines(response, message) return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream( async def run_home_stream(
user_id: str, user_id: str,
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context) system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
if not manifest_block:
# No specific project context — surface all linked folders so the agent
# can answer questions like "tell me about project X" using its files.
manifest_block = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
trace_id = _trace_id_from_context(prepared_context)
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
text_chunks: list[str] = [] text_chunks: list[str] = []
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
@@ -1227,6 +1213,7 @@ async def run_home_stream(
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="home-agent", agent_name="home-agent",
tools=tools,
conversation_history=context.get("conversation_history"), conversation_history=context.get("conversation_history"),
): ):
event_type, data = event event_type, data = event
@@ -1240,47 +1227,99 @@ async def run_home_stream(
yield "token", normalized yield "token", normalized
async def run_floating_stream( async def run_contextual_stream(
user_id: str, user_id: str,
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
scope: "ContextualScope", # type: ignore[name-defined]
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context) """Run the contextual agent for a single user turn.
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain Injects the rendered scope block into the system prompt and exposes
the contextual tool set.
Note-edit tools (propose_note_edit) are intentionally excluded.
*context contract*: callers MUST include ``context["_debug"]["session_id"]``
(a non-empty str) so that ``_session_id_from_context`` can extract it for
tracing and episode storage downstream. The WS handler in device_ws.py
satisfies this by always populating ``_debug`` before calling this function.
"""
from app.schemas.contextual import ContextualScope, render_scope_block # noqa: PLC0415
prepared_context = await _prepare_context(message, context)
trace_id = _trace_id_from_context(prepared_context)
system_prompt, langfuse_prompt = _build_system_prompt(
"contextual_system", _CONTEXTUAL_SYSTEM_PROMPT, prepared_context,
)
scope_block = render_scope_block(scope)
system_prompt = system_prompt + f"\n\n## Current view\n{scope_block}"
tools = _contextual_tools(user_id, trace_id)
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=system_prompt, system_prompt=system_prompt,
message=message, message=message,
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="floating-agent", agent_name="contextual-agent",
tools=tools,
conversation_history=context.get("conversation_history"), conversation_history=context.get("conversation_history"),
): ):
event_type, data = event
if event_type != "token":
yield event yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize() async def run_task_brief_research_stream(
if tail: user_id: str,
emitted_sanitized = True task_id: str,
yield "token", tail context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stage-1 executive assistant: deep research for one task.
if not emitted_sanitized and raw_chunks: Yields ``("token", chunk)`` events like other stream runners.
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks)) The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
which the WS handler strips and emits as a ``canvas_draft`` mutation.
"""
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(f"task:{task_id}", context)
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
# Inject task_id so the agent knows what to look up first.
research_message = (
f"Prepare a briefing dossier for task ID: {task_id}\n"
"Follow the research workflow: read the task, then project, then client, "
"then cross-project relations, then relevant memory. "
"End with a concrete suggested first step. "
"If this is a writing task, include a <canvas kind=\"...\"> draft."
)
system_prompt, langfuse_prompt = _build_system_prompt(
"task_brief_research_system",
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT,
prepared_context,
)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=research_message,
context=prepared_context,
max_steps=12,
langfuse_prompt=langfuse_prompt,
agent_name="task-brief-agent",
tools=tools,
conversation_history=None,
):
yield event
async def update_core_memory(user_id: str, key: str, value: str) -> None: async def update_core_memory(user_id: str, key: str, value: str) -> None:

183
app/core/folder_indexer.py Normal file
View File

@@ -0,0 +1,183 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_langfuse,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
summary: str
tokens_used: int
async def _llm_text(messages: list) -> object:
"""Make the LLM call for text summarisation.
Defined as a standalone async function so tests can patch it cleanly
without needing to mock the LLM object itself.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def _llm_vision(messages: list) -> object:
"""Make the LLM call for vision (image) summarisation.
Accepts the message list and returns the response directly, mirroring
the ``_llm_text`` caller pattern so tests can patch it at the module level.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
"""Return a compact summary of an image file using vision.
Parameters
----------
image_b64:
Base64-encoded image bytes.
mime:
MIME type of the image, e.g. ``"image/png"``.
file_name:
Optional file name, attached to the Langfuse trace as input metadata.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
messages = [
SystemMessage(content=template),
HumanMessage(content=[
{"type": "text", "text": "Summarise this image."},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
]),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-image",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": file_name, "mime": mime},
) as gen:
response = await _llm_vision(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_vision(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
"""Return a compact summary of a text file.
Parameters
----------
content:
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
ext:
File extension including the leading dot, e.g. ``".md"``.
name:
File name, e.g. ``"kickoff.md"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
truncated = content[:_MAX_INPUT_CHARS]
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
messages = [
SystemMessage(content=compiled),
HumanMessage(content="Summarise this file."),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-text",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
) as gen:
response = await _llm_text(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_text(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(pdf_b64))
reader = PdfReader(buf)
parts: list[str] = []
for page in reader.pages:
try:
parts.append(page.extract_text() or "")
except Exception:
continue
return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(docx_b64))
doc = DocxDocument(buf)
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a PDF file.
Parameters
----------
pdf_b64:
Base64-encoded PDF bytes.
name:
File name, e.g. ``"report.pdf"``.
"""
text = _extract_pdf_text(pdf_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".pdf", name=name)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a DOCX file.
Parameters
----------
docx_b64:
Base64-encoded DOCX bytes.
name:
File name, e.g. ``"spec.docx"``.
"""
text = _extract_docx_text(docx_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".docx", name=name)

View File

@@ -103,14 +103,15 @@ def get_llm(
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = { _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL, "classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL, "home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL, "unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL, "cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL, "brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL, "setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini", "memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini", "memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL, "memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
"note-summarizer": lambda: "gpt-4o-mini",
} }

View File

@@ -0,0 +1,51 @@
"""Note summarizer — generates a compact AI summary for a note.
Called fire-and-forget from create_note / update_note tools so the
``notes.ai_summary`` column stays current without blocking the agent loop.
"""
from __future__ import annotations
import logging
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.langfuse_client import get_prompt_or_fallback
from app.core.llm import get_agent_llm
logger = logging.getLogger(__name__)
_FALLBACK_PROMPT = """\
Summarize this note in <=250 characters. Be terse and dense.
Keep proper nouns, dates, decisions, and action items.
Do not start with "This note".
Respond with the summary text only — no intro, no labels.
Title: {title}
Content: {content}"""
_MAX_CONTENT_CHARS = 4000
async def generate_note_summary(title: str, content: str) -> str:
"""Return a <=250-char summary of *title* + *content*.
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
fallback. Truncates *content* to 4000 chars before sending to avoid
token waste on large notes.
"""
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
trimmed = content[:_MAX_CONTENT_CHARS]
system_prompt = template.format(title=title, content=trimmed)
try:
llm = get_agent_llm("note-summarizer")
response = await llm.ainvoke([
SystemMessage(content=system_prompt),
HumanMessage(content="Generate the summary."),
])
text = response.content if isinstance(response.content, str) else ""
return text.strip()[:250]
except Exception as exc:
logger.warning("note_summarizer: failed to generate summary: %s", exc)
return ""

View File

@@ -2,12 +2,36 @@
from __future__ import annotations from __future__ import annotations
import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain # Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
re.DOTALL | re.IGNORECASE,
)
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
Returns ``(visible_text, canvas_content, canvas_kind)``.
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
"""
match = _CANVAS_BLOCK_RE.search(text)
if not match:
return text, None, None
canvas_kind = match.group(1).strip()
canvas_content = match.group(2).strip()
visible = text[: match.start()] + text[match.end() :]
visible = visible.strip()
return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
class StreamFormatter: class StreamFormatter:
@@ -23,14 +47,6 @@ class StreamFormatter:
started = False started = False
async for event_type, data in event_stream: async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token": if event_type != "token":
continue continue

View File

@@ -48,7 +48,7 @@ from app.core.llm import get_agent_llm, model_for_agent
from app.core.preprocessors import detect_content_type, preprocess from app.core.preprocessors import detect_content_type, preprocess
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
from app.db import async_session from app.db import async_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig from app.models import ScoutRunLog, CloudScoutConfig, LocalScoutConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -169,7 +169,7 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
next_run: datetime = cron.get_next(datetime) next_run: datetime = cron.get_next(datetime)
return now >= next_run return now >= next_run
except Exception as exc: except Exception as exc:
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc) logger.warning("scout_runner: cannot parse cron %r: %s", schedule_cron, exc)
return False return False
@@ -290,7 +290,7 @@ async def _run_agent_with_tools(
call_name = str(call.get("name", "")) call_name = str(call.get("name", ""))
call_args = call.get("args", {}) call_args = call.get("args", {})
logger.info( logger.info(
"agent_runner: tool_call name=%s args=%s", "scout_runner: tool_call name=%s args=%s",
call_name, call_name,
json.dumps(call_args, ensure_ascii=True)[:800], json.dumps(call_args, ensure_ascii=True)[:800],
) )
@@ -305,7 +305,7 @@ async def _run_agent_with_tools(
tool_output = await tool_fn.ainvoke(call_args) tool_output = await tool_fn.ainvoke(call_args)
logger.info( logger.info(
"agent_runner: tool_result name=%s output=%s", "scout_runner: tool_result name=%s output=%s",
call_name, call_name,
str(tool_output)[:200], str(tool_output)[:200],
) )
@@ -360,7 +360,7 @@ async def _scan_directories(
try: try:
result = await execute_on_client(action="list_directory", data={"path": path}) result = await execute_on_client(action="list_directory", data={"path": path})
except Exception as exc: except Exception as exc:
logger.warning("agent_runner: list_directory failed %r: %s", path, exc) logger.warning("scout_runner: list_directory failed %r: %s", path, exc)
return return
for entry in result.get("entries", []): for entry in result.get("entries", []):
entry_path = entry.get("path", "") entry_path = entry.get("path", "")
@@ -414,7 +414,7 @@ async def _fetch_projects() -> list[dict]:
result = await execute_on_client(action="select", table="projects") result = await execute_on_client(action="select", table="projects")
return result.get("rows", []) return result.get("rows", [])
except Exception as exc: except Exception as exc:
logger.warning("agent_runner: failed to fetch projects: %s", exc) logger.warning("scout_runner: failed to fetch projects: %s", exc)
return [] return []
@@ -442,7 +442,7 @@ async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
) )
return result.get("rows", []) return result.get("rows", [])
except Exception as exc: except Exception as exc:
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc) logger.warning("scout_runner: failed to fetch %s: %s", domain, exc)
return [] return []
@@ -555,8 +555,8 @@ def _get_no_match_behavior(agent_config: dict) -> str:
async def run_local_agent( async def run_local_agent(
user_id: str, user_id: str,
config: LocalAgentConfig, config: LocalScoutConfig,
run_log: AgentRunLog, run_log: ScoutRunLog,
device_mgr: DeviceConnectionManager, device_mgr: DeviceConnectionManager,
run_context: dict | None = None, run_context: dict | None = None,
) -> None: ) -> None:
@@ -586,7 +586,7 @@ async def run_local_agent(
if not is_online: if not is_online:
logger.info( logger.info(
"agent_runner: skip run=%s — device %r offline for user=%s", "scout_runner: skip run=%s — device %r offline for user=%s",
run_id, run_id,
target_device_id or "<any>", target_device_id or "<any>",
user_id, user_id,
@@ -605,7 +605,7 @@ async def run_local_agent(
errors: list[str] = [] errors: list[str] = []
items_processed = 0 items_processed = 0
items_created = 0 items_created = 0
agent_config: dict = config.agent_config or {} agent_config: dict = config.scout_config or {}
processing_tools = _build_processing_tools(config.data_types) processing_tools = _build_processing_tools(config.data_types)
try: try:
@@ -616,7 +616,7 @@ async def run_local_agent(
last_run_at=config.last_run_at, last_run_at=config.last_run_at,
) )
logger.info( logger.info(
"agent_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths) "scout_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths)
) )
if not file_paths: if not file_paths:
@@ -641,7 +641,7 @@ async def run_local_agent(
raw_content: str = file_result.get("content", "") raw_content: str = file_result.get("content", "")
if not raw_content.strip(): if not raw_content.strip():
logger.debug( logger.debug(
"agent_runner: run=%s skipping empty file %r", run_id, file_path "scout_runner: run=%s skipping empty file %r", run_id, file_path
) )
continue continue
@@ -651,16 +651,21 @@ async def run_local_agent(
preprocessed = preprocess(content_type, raw_content) preprocessed = preprocess(content_type, raw_content)
logger.info( logger.info(
"agent_runner: run=%s file=%r content_type=%s clean_len=%d", "scout_runner: run=%s file=%r content_type=%s clean_len=%d",
run_id, file_path, content_type, len(preprocessed.clean_text), run_id, file_path, content_type, len(preprocessed.clean_text),
) )
# ── Phase B: single LLM call ───────────────────────── # ── Phase B: single LLM call ─────────────────────────
extraction_rules = _get_extraction_rules(agent_config, content_type) extraction_rules = _get_extraction_rules(agent_config, content_type)
no_match_behavior = _get_no_match_behavior(agent_config) no_match_behavior = _get_no_match_behavior(agent_config)
global_rules_lines = "\n".join( base_global_rules = list(agent_config.get("global_rules", []))
f"- {r}" for r in agent_config.get("global_rules", []) if "notes" in config.data_types:
base_global_rules.append(
"For notes: when updating an existing note use `propose_note_edit` "
"(type=append/insert/replace) so the user can review AI changes. "
"Only call `update_note` for complete content replacement without review."
) )
global_rules_lines = "\n".join(f"- {r}" for r in base_global_rules)
metadata_section = _format_metadata(preprocessed.metadata) metadata_section = _format_metadata(preprocessed.metadata)
system_prompt = compile_prompt( system_prompt = compile_prompt(
@@ -706,19 +711,19 @@ async def run_local_agent(
projects_block = _format_projects(projects) projects_block = _format_projects(projects)
logger.info( logger.info(
"agent_runner: run=%s file=%r created=%d result=%s", "scout_runner: run=%s file=%r created=%d result=%s",
run_id, file_path, file_created, result_text[:200], run_id, file_path, file_created, result_text[:200],
) )
except Exception as exc: except Exception as exc:
errors.append(f"Error processing '{file_path}': {exc}") errors.append(f"Error processing '{file_path}': {exc}")
logger.error( logger.error(
"agent_runner: run=%s file=%r failed: %s", run_id, file_path, exc "scout_runner: run=%s file=%r failed: %s", run_id, file_path, exc
) )
except Exception as exc: except Exception as exc:
errors.append(f"Agent run failed: {exc}") errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_id, exc) logger.error("scout_runner: run=%s failed: %s", run_id, exc)
finally: finally:
_running_agents.discard(agent_id) _running_agents.discard(agent_id)
clear_client_executor() clear_client_executor()
@@ -739,7 +744,7 @@ async def run_local_agent(
errors=errors, errors=errors,
) )
logger.info( logger.info(
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d", "scout_runner: run=%s done status=%s processed=%d created=%d errors=%d",
run_id, run_id,
final_status, final_status,
items_processed, items_processed,
@@ -757,7 +762,7 @@ async def run_local_agent(
}) })
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"agent_runner: run=%s failed to send run_complete: %s", run_id, exc "scout_runner: run=%s failed to send run_complete: %s", run_id, exc
) )
@@ -768,8 +773,8 @@ _CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent( async def run_cloud_agent(
user_id: str, user_id: str,
config: CloudAgentConfig, config: CloudScoutConfig,
run_log: AgentRunLog, run_log: ScoutRunLog,
device_mgr: DeviceConnectionManager, device_mgr: DeviceConnectionManager,
) -> None: ) -> None:
"""Execute a cloud connector agent run end-to-end. """Execute a cloud connector agent run end-to-end.
@@ -792,7 +797,7 @@ async def run_cloud_agent(
# ── 1. Device online check ───────────────────────────────────────── # ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id): if not device_mgr.is_online(user_id):
logger.info( logger.info(
"agent_runner: skip cloud run=%s — no device online for user=%s", "scout_runner: skip cloud run=%s — no device online for user=%s",
run_id, run_id,
user_id, user_id,
) )
@@ -817,7 +822,7 @@ async def run_cloud_agent(
try: try:
credentials_info = decrypt_token(config.oauth_token_encrypted) credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc: except ValueError as exc:
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc) logger.error("scout_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
await _finalize_run( await _finalize_run(
run_log, run_log,
status="error", status="error",
@@ -863,7 +868,7 @@ async def run_cloud_agent(
raw_messages = [] raw_messages = []
except RuntimeError as exc: except RuntimeError as exc:
logger.error( logger.error(
"agent_runner: provider fetch failed for cloud agent %s: %s", config.id, exc "scout_runner: provider fetch failed for cloud agent %s: %s", config.id, exc
) )
await _finalize_run( await _finalize_run(
run_log, run_log,
@@ -876,7 +881,7 @@ async def run_cloud_agent(
return return
logger.info( logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s", "scout_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
config.id, config.id,
len(raw_messages), len(raw_messages),
config.provider, config.provider,
@@ -936,16 +941,16 @@ async def run_cloud_agent(
new_encrypted = encrypt_token(refreshed) new_encrypted = encrypt_token(refreshed)
async with async_session() as db: async with async_session() as db:
cfg_result = await db.execute( cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id) select(CloudScoutConfig).where(CloudScoutConfig.id == config.id)
) )
cfg_row = cfg_result.scalar_one_or_none() cfg_row = cfg_result.scalar_one_or_none()
if cfg_row: if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted cfg_row.oauth_token_encrypted = new_encrypted
await db.commit() await db.commit()
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id) logger.debug("scout_runner: refreshed OAuth token persisted for agent %s", config.id)
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"agent_runner: failed to persist refreshed token for agent %s: %s", "scout_runner: failed to persist refreshed token for agent %s: %s",
config.id, config.id,
exc, exc,
) )
@@ -969,7 +974,7 @@ async def run_cloud_agent(
config_type="cloud", config_type="cloud",
) )
logger.info( logger.info(
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d", "scout_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
run_id, run_id,
final_status, final_status,
items_processed, items_processed,
@@ -991,7 +996,7 @@ async def trigger_pending_runs(
Called as a background task from the device WS endpoint on ``device_hello``. Called as a background task from the device WS endpoint on ``device_hello``.
""" """
logger.info( logger.info(
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)", "scout_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
user_id, user_id,
device_id, device_id,
) )
@@ -1002,7 +1007,7 @@ async def trigger_pending_runs(
async def _finalize_run( async def _finalize_run(
run_log: AgentRunLog, run_log: ScoutRunLog,
*, *,
status: str, status: str,
items_processed: int = 0, items_processed: int = 0,
@@ -1026,14 +1031,14 @@ async def _finalize_run(
if update_config_last_run and config_id: if update_config_last_run and config_id:
if config_type == "local": if config_type == "local":
cfg_result = await db.execute( cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) select(LocalScoutConfig).where(LocalScoutConfig.id == config_id)
) )
cfg = cfg_result.scalar_one_or_none() cfg = cfg_result.scalar_one_or_none()
if cfg: if cfg:
cfg.last_run_at = now cfg.last_run_at = now
elif config_type == "cloud": elif config_type == "cloud":
cfg_result = await db.execute( cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id) select(CloudScoutConfig).where(CloudScoutConfig.id == config_id)
) )
cfg = cfg_result.scalar_one_or_none() cfg = cfg_result.scalar_one_or_none()
if cfg: if cfg:
@@ -1042,5 +1047,5 @@ async def _finalize_run(
await db.commit() await db.commit()
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc "scout_runner: failed to finalize run_log=%s: %s", run_log.id, exc
) )

View File

@@ -54,6 +54,43 @@ class _SessionBuffer:
with self._lock: with self._lock:
self._store.pop((user_id, session_id), None) self._store.pop((user_id, session_id), None)
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
"""Append a synthetic system message to the buffer for the given session.
Creates the session slot if it does not yet exist. Used by the
contextual_scope_update handler to inject navigation events without
making an LLM call.
"""
from langchain_core.messages import SystemMessage # noqa: PLC0415
key = (user_id, session_id)
with self._lock:
entry = self._store.get(key)
if entry is None:
msgs: list[BaseMessage] = [SystemMessage(content=text)]
else:
_, existing = entry
msgs = list(existing) + [SystemMessage(content=text)]
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
self._store[key] = (time.monotonic(), capped)
class ContextualBufferProxy:
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
Returned by get_session_buffer() so callers can call
``proxy.append_system_message(text)`` without threading user_id/session_id
through every call site.
"""
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
self._buf = buf
self._user_id = user_id
self._session_id = session_id
def append_system_message(self, text: str) -> None:
self._buf.append_system_message(self._user_id, self._session_id, text)
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py # Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
session_buffer = _SessionBuffer() session_buffer = _SessionBuffer()

View File

@@ -8,7 +8,7 @@ blocking the event loop.
Token refresh is handled transparently: when the stored access token has Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted`` any refreshed credentials back to ``CloudScoutConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``). (see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2): Credential dict shape (Google OAuth2):

View File

@@ -77,8 +77,98 @@ async def _memory_cron_tick() -> None:
_log.warning("memory cron tick: failed: %s", exc) _log.warning("memory cron tick: failed: %s", exc)
async def _scout_cron_tick() -> None:
"""Every-15-min cron: poll enabled cloud scouts (cron-fallback; push is primary).
Skips any scout whose ``last_run_at`` is within the last 5 minutes so
a push notification and the fallback cron don't double-fire within the
same window.
"""
import logging # noqa: PLC0415
import uuid # noqa: PLC0415
from datetime import datetime, timezone # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("scout cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.models import CloudScoutConfig # noqa: PLC0415
from app.scouts.engine import ScoutEngine # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as session:
scouts = (await session.execute(
select(CloudScoutConfig).where(CloudScoutConfig.enabled == True) # noqa: E712
)).scalars().all()
engine = ScoutEngine()
triggered = 0
for scout in scouts:
# Rate-limit guard: push is primary; skip if ran within 5 minutes.
if scout.last_run_at:
elapsed = (datetime.now(tz=timezone.utc) - scout.last_run_at).total_seconds()
if elapsed < 300:
continue
try:
await engine.trigger_scout(uuid.UUID(str(scout.id)))
triggered += 1
except Exception as exc:
_log.warning("scout cron tick: trigger failed scout=%s: %s", scout.id, exc)
_log.info("scout cron tick: done triggered=%d total=%d", triggered, len(scouts))
except Exception as exc:
_log.warning("scout cron tick: failed: %s", exc)
async def _scout_watch_renewal_tick() -> None:
"""Every-24-hour cron: re-issue Gmail users.watch for scouts expiring within 24h.
Handles missing or misconfigured connectors gracefully — logs and continues.
"""
import logging # noqa: PLC0415
from datetime import datetime, timedelta, timezone # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("scout watch renewal tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.models import CloudScoutConfig # noqa: PLC0415
from app.scouts.connectors.registry import get_connector # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
threshold = datetime.now(tz=timezone.utc) + timedelta(hours=24)
renewed = 0
async with async_session() as session:
scouts = (await session.execute(
select(CloudScoutConfig).where(
CloudScoutConfig.enabled == True, # noqa: E712
CloudScoutConfig.provider == "gmail",
CloudScoutConfig.gmail_watch_expires_at <= threshold,
)
)).scalars().all()
for scout in scouts:
try:
connector = get_connector("gmail")
await connector.renew_watch(scout)
renewed += 1
except Exception:
_log.exception("scout watch renewal tick: renew failed scout=%s", scout.id)
await session.commit()
_log.info("scout watch renewal tick: done renewed=%d", renewed)
except Exception as exc:
_log.warning("scout watch renewal tick: failed: %s", exc)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Startup: register source connectors.
from app.scouts.connectors.gmail import GmailConnector # noqa: PLC0415
from app.scouts.connectors.registry import register_connector # noqa: PLC0415
register_connector(GmailConnector())
# Startup: ensure agent tool modules are loaded. # Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401 import app.agents # noqa: F401
@@ -89,6 +179,14 @@ async def lifespan(app: FastAPI):
scheduler = AsyncIOScheduler() scheduler = AsyncIOScheduler()
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron") scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron") scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
scheduler.add_job(
_scout_cron_tick, "interval", minutes=15,
id="scout_cron_tick", replace_existing=True,
)
scheduler.add_job(
_scout_watch_renewal_tick, "interval", hours=24,
id="scout_watch_renewal_tick", replace_existing=True,
)
scheduler.start() scheduler.start()
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)") logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
@@ -124,12 +222,13 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware) app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware) app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, billing, chat, device_ws, memory from app.api.routes import scouts, auth, billing, chat, device_ws, memory, scout_webhooks
app.include_router(auth.router, prefix="/api/v1") app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1") app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1") app.include_router(scouts.router, prefix="/api/v1")
app.include_router(scout_webhooks.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1") app.include_router(device_ws.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1") app.include_router(memory.router, prefix="/api/v1")

View File

@@ -1,15 +1,15 @@
"""SQLAlchemy ORM models for all persistent tables. """SQLAlchemy ORM models for all persistent tables.
Only auth, billing, agent config, and memory data live here. Only auth, billing, scout config, and memory data live here.
User content (notes, tasks, etc.) lives exclusively on the client. User content (notes, tasks, etc.) lives exclusively on the client.
Table inventory: Table inventory:
users — account credentials + tier users — account credentials + tier
refresh_tokens — hashed refresh token store refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records subscriptions — Stripe subscription records
local_agent_configs — per-device batch agent configs local_scout_configs — per-device batch scout configs
cloud_agent_configs — OAuth-backed cloud agent configs cloud_scout_configs — OAuth-backed cloud scout configs
agent_run_logs — execution history for all agents scout_run_logs — execution history for all scouts
memory_core — per-user persistent key/value preferences (encrypted) memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted) memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted) memory_episodic — per-user session summaries (encrypted)
@@ -34,8 +34,10 @@ from sqlalchemy import (
LargeBinary, LargeBinary,
String, String,
Text, Text,
UniqueConstraint,
Uuid, Uuid,
func, func,
text,
) )
from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -158,8 +160,8 @@ class Subscription(Base):
user: Mapped[User] = relationship(back_populates="subscription") user: Mapped[User] = relationship(back_populates="subscription")
class LocalAgentConfig(Base): class LocalScoutConfig(Base):
__tablename__ = "local_agent_configs" __tablename__ = "local_scout_configs"
id: Mapped[str] = mapped_column( id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid Uuid(as_uuid=False), primary_key=True, default=_uuid
@@ -172,7 +174,7 @@ class LocalAgentConfig(Base):
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list) directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list) data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="") prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True) scout_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list) file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *") schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@@ -184,17 +186,17 @@ class LocalAgentConfig(Base):
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
) )
run_logs: Mapped[list[AgentRunLog]] = relationship( run_logs: Mapped[list["ScoutRunLog"]] = relationship(
back_populates="local_agent", back_populates="local_scout",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')", primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
foreign_keys="AgentRunLog.agent_id", foreign_keys="ScoutRunLog.scout_id",
cascade="all, delete-orphan", cascade="all, delete-orphan",
overlaps="run_logs,cloud_agent", overlaps="run_logs,cloud_scout",
) )
class CloudAgentConfig(Base): class CloudScoutConfig(Base):
__tablename__ = "cloud_agent_configs" __tablename__ = "cloud_scout_configs"
id: Mapped[str] = mapped_column( id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid Uuid(as_uuid=False), primary_key=True, default=_uuid
@@ -217,52 +219,88 @@ class CloudAgentConfig(Base):
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
) )
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
run_logs: Mapped[list[AgentRunLog]] = relationship( run_logs: Mapped[list["ScoutRunLog"]] = relationship(
back_populates="cloud_agent", back_populates="cloud_scout",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')", primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id", foreign_keys="ScoutRunLog.scout_id",
cascade="all, delete-orphan", cascade="all, delete-orphan",
overlaps="run_logs,local_agent", overlaps="run_logs,local_scout",
) )
class AgentRunLog(Base): class ScoutTriageQueue(Base):
__tablename__ = "agent_run_logs" __tablename__ = "scout_triage_queue"
__table_args__ = (
UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
)
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
scout_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False)
source_type: Mapped[str] = mapped_column(String(50), nullable=False)
source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)
triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(20), nullable=False, default="queued", server_default="queued")
triaged_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
class ScoutRunLog(Base):
__tablename__ = "scout_run_logs"
id: Mapped[str] = mapped_column( id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid Uuid(as_uuid=False), primary_key=True, default=_uuid
) )
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs # Plain string — not a FK because it references either local_scout_configs or cloud_scout_configs
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config. # depending on scout_type. Query by (scout_id, scout_type) to locate the source config.
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True) scout_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False) scout_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
user_id: Mapped[str] = mapped_column( user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
) )
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running") status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0) items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0) items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
errors: Mapped[list | None] = mapped_column(JSON, nullable=True) errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column( started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
local_agent: Mapped[LocalAgentConfig | None] = relationship( local_scout: Mapped["LocalScoutConfig | None"] = relationship(
back_populates="run_logs", back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')", primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
foreign_keys="AgentRunLog.agent_id", foreign_keys="ScoutRunLog.scout_id",
overlaps="run_logs,cloud_agent", overlaps="run_logs,cloud_scout",
) )
cloud_agent: Mapped[CloudAgentConfig | None] = relationship( cloud_scout: Mapped["CloudScoutConfig | None"] = relationship(
back_populates="run_logs", back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')", primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id", foreign_keys="ScoutRunLog.scout_id",
overlaps="run_logs,local_agent", overlaps="run_logs,local_scout",
) )
class MonthlyTokenUsage(Base):
__tablename__ = "monthly_token_usage"
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
# ── Memory models ───────────────────────────────────────────────────────────── # ── Memory models ─────────────────────────────────────────────────────────────

View File

@@ -73,11 +73,9 @@ class WsFrameType(str, Enum):
device_hello = "device_hello" device_hello = "device_hello"
# ── v3 frame types ───────────────────────────────────────────────── # ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request" home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start" stream_start = "stream_start"
stream_text = "stream_text" stream_text = "stream_text"
stream_end = "stream_end" stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request" data_request = "data_request"
data_response = "data_response" data_response = "data_response"
mutation = "mutation" mutation = "mutation"
@@ -87,6 +85,22 @@ class WsFrameType(str, Enum):
journey_reply = "journey_reply" journey_reply = "journey_reply"
# ── v5 brief frame types ────────────────────────────────────────── # ── v5 brief frame types ──────────────────────────────────────────
brief_request = "brief_request" brief_request = "brief_request"
# ── v6 task brief frame types ─────────────────────────────────────
task_brief_request = "task_brief_request"
# ── v7 folder index frame types ───────────────────────────────────
index_session_start = "index_session_start"
index_file_batch = "index_file_batch"
index_session_cancel = "index_session_cancel"
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
# ── v8 contextual sidebar frame types ────────────────────────────
contextual_request = "contextual_request"
contextual_scope_update = "contextual_scope_update"
contextual_scope_ack = "contextual_scope_ack"
# ── v9 scout proposal frame types ────────────────────────────────
SCOUT_PROPOSAL = "scout_proposal"
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
class WsToolCall(BaseModel): class WsToolCall(BaseModel):
@@ -136,7 +150,7 @@ class WsDeviceHello(BaseModel):
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
device_id: str device_id: str
agent_ids: list[str] = Field(default_factory=list) scout_ids: list[str] = Field(default_factory=list)
@@ -152,13 +166,6 @@ class FormatPrefsModel(BaseModel):
now_iso: str = "" now_iso: str = ""
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel): class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message.""" """Client → Server: Home chat message."""
@@ -168,15 +175,6 @@ class WsHomeRequest(BaseModel):
format_prefs: FormatPrefsModel | None = None format_prefs: FormatPrefsModel | None = None
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
format_prefs: FormatPrefsModel | None = None
class WsBriefRequest(BaseModel): class WsBriefRequest(BaseModel):
"""Client → Server: Request a plain-text brief (home or project).""" """Client → Server: Request a plain-text brief (home or project)."""
@@ -209,28 +207,13 @@ class WsStreamEnd(BaseModel):
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str request_id: str
error: str | None = None error: str | None = None
mutations: list[dict[str, Any]] | None = None
class WsDomain(BaseModel): # ── Scout Config V2 ───────────────────────────────────────────────────
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel): class ScoutContentTypeConfig(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Config V2 ───────────────────────────────────────────────────
class ContentTypeConfig(BaseModel):
"""Per-type extraction config produced by the journey chatbot.""" """Per-type extraction config produced by the journey chatbot."""
id: str id: str
@@ -240,34 +223,34 @@ class ContentTypeConfig(BaseModel):
extraction_prompt: str extraction_prompt: str
class AgentConfig(BaseModel): class ScoutConfig(BaseModel):
"""Structured agent configuration (replaces freeform prompt_template).""" """Structured scout configuration (replaces freeform prompt_template)."""
content_types: list[ContentTypeConfig] = [] content_types: list[ScoutContentTypeConfig] = []
global_rules: list[str] = [] global_rules: list[str] = []
data_types: list[str] = [] data_types: list[str] = []
# ── Agent Catalog ───────────────────────────────────────────────────── # ── Scout Catalog ─────────────────────────────────────────────────────
class AgentCatalogItem(BaseModel): class ScoutCatalogItem(BaseModel):
type: str type: str
name: str name: str
description: str description: str
class AgentCreationCheckRequest(BaseModel): class ScoutCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0) active_agents: int = Field(ge=0, default=0)
class AgentCreationCheckResponse(BaseModel): class ScoutCreationCheckResponse(BaseModel):
allowed: bool allowed: bool
tier: BillingTier tier: BillingTier
active_agents: int active_agents: int
limit: int limit: int
class AgentTriggerRequest(BaseModel): class ScoutTriggerRequest(BaseModel):
directory: str = Field(min_length=1) directory: str = Field(min_length=1)
device_id: str = Field(default="") device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID) agent_id: str | None = None # FE stable agent ID (electron-store UUID)
@@ -279,9 +262,9 @@ class AgentTriggerRequest(BaseModel):
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
# ── Agent Run Log ───────────────────────────────────────────────────── # ── Scout Run Log ─────────────────────────────────────────────────────
class AgentRunLogResponse(BaseModel): class ScoutRunLogResponse(BaseModel):
id: str id: str
agent_id: str agent_id: str
agent_type: Literal["local", "cloud"] agent_type: Literal["local", "cloud"]
@@ -295,3 +278,25 @@ class AgentRunLogResponse(BaseModel):
# ── Chatbot Journey ─────────────────────────────────────────────────── # ── Chatbot Journey ───────────────────────────────────────────────────
# ── Scout Proposal Frame Models ───────────────────────────────────────
class ScoutProposalPayload(BaseModel):
id: str
scout_id: str
source_type: str
source_msg_ref: str
raw_subject: str | None = None
raw_snippet: str | None = None
category: Literal["unprocessed"] = "unprocessed"
payload: dict | None = None
class ScoutProposalFrame(BaseModel):
type: Literal[WsFrameType.SCOUT_PROPOSAL]
proposal: ScoutProposalPayload
class ScoutProposalAckFrame(BaseModel):
type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
proposal_id: str

73
app/schemas/contextual.py Normal file
View File

@@ -0,0 +1,73 @@
"""Contextual sidebar scope schema and prompt block renderer.
ContextualScope mirrors the TypeScript ContextualScope type sent by the
Electron renderer when the user opens the side chat anchored to a specific
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
them to snake_case Python attributes automatically.
"""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
PageType = Literal[
"timeline",
"tasks",
"projects-list",
"project",
"note",
]
EntityType = Literal["project", "note", "task", "timeline_event"]
class ContextualScope(BaseModel):
"""Scope payload sent by the Electron renderer for contextual chat.
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
alias generator maps them to snake_case Python attrs.
"""
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
page: PageType
entity_type: Optional[EntityType] = None
entity_id: Optional[str] = None
entity_name: Optional[str] = None
project_id: Optional[str] = None
char_count: Optional[int] = None
counts: Optional[dict[str, int]] = None
filters: Optional[dict] = None
def render_scope_block(scope: ContextualScope) -> str:
"""Produce a single-paragraph human-readable summary of the current view
for injection into the contextual agent system prompt.
Never emits internal ids — only names. The LLM is told to use names in
prose; ids travel through tool calls.
"""
if scope.entity_type == "project":
c = scope.counts or {}
return (
f"User is viewing the project {scope.entity_name!r}. "
f"{c.get('tasks', 0)} tasks, "
f"{c.get('notes', 0)} notes, "
f"{c.get('milestones', 0)} milestones."
)
if scope.entity_type == "note":
return (
f"User is viewing the note {scope.entity_name!r} "
f"({scope.char_count or 0} characters)."
)
if scope.page == "tasks":
return "User is viewing the global Tasks list (all projects)."
if scope.page == "timeline":
return "User is viewing the global Timeline view."
if scope.page == "projects-list":
return "User is viewing the Projects list."
return f"User is on page {scope.page}."

0
app/scouts/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,56 @@
"""Source connector Protocol and shared item types.
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
shared ScoutEngine interface. Each connector owns:
* how to enumerate new items since the last poll (``list_new``)
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
* how to fetch a single item's full content for in-memory triage
(``fetch_content``) — this content MUST NOT be persisted by the engine
* how to archive/trash an item (``archive``) for spam handling
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
"""
from __future__ import annotations
from datetime import datetime
from typing import Literal, Protocol
from pydantic import BaseModel, Field
class ItemRef(BaseModel):
source_msg_ref: str
received_at: datetime | None = None
class ItemMetadata(BaseModel):
subject: str | None = None
sender: str | None = None
snippet: str | None = None
received_at: datetime | None = None
class ItemContent(BaseModel):
metadata: ItemMetadata
body_text: str
raw_headers: dict[str, str] = Field(default_factory=dict)
class TriageVerdict(BaseModel):
verdict: Literal["relevant", "spam"]
reason: str
confidence: float = Field(ge=0.0, le=1.0)
class SourceConnector(Protocol):
"""Adapter for a third-party data source (Gmail, Slack, ...)."""
source_type: str # e.g. "gmail"
async def list_new(self, scout) -> list[ItemRef]: ...
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata: ...
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent: ...
async def archive(self, scout, ref: ItemRef) -> None: ...
async def setup_watch(self, scout) -> None: ...
async def renew_watch(self, scout) -> None: ...

View File

@@ -0,0 +1,213 @@
"""Gmail SourceConnector — wraps the existing GmailClient.
Responsibilities:
* list_new: incremental fetch since the scout's stored gmail_history_id
* fetch_metadata: subject + sender + snippet only (Gmail metadata format)
* fetch_content: full body text — transient, never persisted by engine
* archive: move a message to Gmail Trash (recoverable for 30 days)
* setup_watch / renew_watch: Gmail push notifications via Pub/Sub
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timezone
from app.config.settings import settings
from app.integrations import decrypt_token
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef
logger = logging.getLogger(__name__)
def _extract_plain_text_body(payload: dict) -> str:
"""Recursively walk a Gmail message payload to find text/plain content."""
import base64
mime_type = payload.get("mimeType", "")
if mime_type == "text/plain":
data = payload.get("body", {}).get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type.startswith("multipart/"):
for part in payload.get("parts", []):
text = _extract_plain_text_body(part)
if text:
return text
# text/html fallback: strip tags rudimentarily if no text/plain part
if mime_type == "text/html":
data = payload.get("body", {}).get("data", "")
if data:
import re
html = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return re.sub(r"<[^>]+>", " ", html)
return ""
def _get_gmail_service(scout):
"""Return a synchronous Google API client for low-level metadata/history calls."""
from googleapiclient.discovery import build
from google.oauth2.credentials import Credentials
creds_info = decrypt_token(scout.oauth_token_encrypted)
credentials = Credentials(
token=creds_info.get("token"),
refresh_token=creds_info.get("refresh_token"),
token_uri=creds_info.get("token_uri", "https://oauth2.googleapis.com/token"),
client_id=creds_info.get("client_id"),
client_secret=creds_info.get("client_secret"),
scopes=creds_info.get("scopes"),
)
return build("gmail", "v1", credentials=credentials, cache_discovery=False)
class GmailConnector:
source_type = "gmail"
# ── list_new ──────────────────────────────────────────────────────────
async def list_new(self, scout) -> list[ItemRef]:
"""Return new message refs since scout.gmail_history_id.
On first run (gmail_history_id is None/empty), records the current
historyId without backfilling — avoids flooding the user with old mail.
Updates scout.gmail_history_id in-place (caller must persist to DB).
"""
def _sync() -> tuple[list[ItemRef], str | None]:
service = _get_gmail_service(scout)
history_id = scout.gmail_history_id
refs: list[ItemRef] = []
new_history_id = history_id
if history_id:
resp = (
service.users()
.history()
.list(
userId="me",
startHistoryId=history_id,
historyTypes=["messageAdded"],
)
.execute()
)
for entry in resp.get("history", []):
for added in entry.get("messagesAdded", []):
refs.append(ItemRef(source_msg_ref=added["message"]["id"]))
new_history_id = resp.get("historyId", history_id)
else:
# First run: capture baseline history id without backfilling.
profile = service.users().getProfile(userId="me").execute()
new_history_id = profile["historyId"]
return refs, new_history_id
refs, new_history_id = await asyncio.to_thread(_sync)
if new_history_id and new_history_id != scout.gmail_history_id:
scout.gmail_history_id = new_history_id
return refs
# ── fetch_metadata ────────────────────────────────────────────────────
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
"""Fetch subject, sender, snippet only — uses Gmail metadata format (no body)."""
def _sync() -> ItemMetadata:
service = _get_gmail_service(scout)
msg = (
service.users()
.messages()
.get(
userId="me",
id=ref.source_msg_ref,
format="metadata",
metadataHeaders=["Subject", "From", "Date"],
)
.execute()
)
headers = {
h["name"]: h["value"]
for h in msg.get("payload", {}).get("headers", [])
}
return ItemMetadata(
subject=headers.get("Subject"),
sender=headers.get("From"),
snippet=msg.get("snippet"),
received_at=None,
)
return await asyncio.to_thread(_sync)
# ── fetch_content ─────────────────────────────────────────────────────
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
"""Fetch full body text for a single message — transient, must not be persisted."""
def _sync() -> ItemContent:
service = _get_gmail_service(scout)
msg = service.users().messages().get(
userId="me", id=ref.source_msg_ref, format="full",
).execute()
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
body_text = _extract_plain_text_body(msg.get("payload", {}))
return ItemContent(
metadata=ItemMetadata(
subject=headers.get("Subject"),
sender=headers.get("From"),
snippet=msg.get("snippet"),
received_at=None,
),
body_text=body_text,
raw_headers=headers,
)
return await asyncio.to_thread(_sync)
# ── archive ───────────────────────────────────────────────────────────
async def archive(self, scout, ref: ItemRef) -> None:
"""Move the message to Gmail Trash (recoverable for 30 days)."""
def _sync() -> None:
service = _get_gmail_service(scout)
service.users().messages().trash(
userId="me", id=ref.source_msg_ref
).execute()
await asyncio.to_thread(_sync)
# ── watch management ──────────────────────────────────────────────────
async def setup_watch(self, scout) -> None:
"""Register a Gmail Pub/Sub push watch for the INBOX label.
Requires ``settings.GMAIL_PUBSUB_TOPIC`` to be set to the full topic
resource name (e.g. ``projects/my-project/topics/gmail-push``).
Logs a warning and returns without error if the topic is not configured.
"""
topic = settings.GMAIL_PUBSUB_TOPIC
if not topic:
logger.warning(
"setup_watch: GMAIL_PUBSUB_TOPIC is not configured — skipping watch setup"
)
return
def _sync() -> None:
service = _get_gmail_service(scout)
request_body = {
"labelIds": ["INBOX"],
"topicName": topic,
}
resp = service.users().watch(userId="me", body=request_body).execute()
scout.gmail_history_id = resp.get("historyId")
expiration_ms = resp.get("expiration")
if expiration_ms:
scout.gmail_watch_expires_at = datetime.fromtimestamp(
int(expiration_ms) / 1000, tz=timezone.utc
)
await asyncio.to_thread(_sync)
async def renew_watch(self, scout) -> None:
"""Renew an existing Gmail Pub/Sub watch (same as setup_watch)."""
await self.setup_watch(scout)

View File

@@ -0,0 +1,32 @@
"""Connector registry — single source of truth for source_type -> connector."""
from __future__ import annotations
from typing import Any
_CONNECTORS: dict[str, Any] = {}
def register_connector(connector: Any) -> None:
"""Register a SourceConnector instance under its ``source_type``.
Calling twice with the same ``source_type`` replaces the prior entry —
useful for tests and hot-reload, but in production each connector
should be registered exactly once at startup.
"""
if not getattr(connector, "source_type", None):
raise ValueError("Connector must declare a non-empty source_type")
_CONNECTORS[connector.source_type] = connector
def get_connector(source_type: str) -> Any:
"""Return the registered connector for ``source_type`` or raise KeyError."""
try:
return _CONNECTORS[source_type]
except KeyError as exc:
raise KeyError(f"No connector registered for source_type {source_type!r}") from exc
def _reset_for_tests() -> None:
"""Clear the registry — for use in pytest fixtures only."""
_CONNECTORS.clear()

270
app/scouts/engine.py Normal file
View File

@@ -0,0 +1,270 @@
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
Triage flow per scout:
1. Resolve scout config from the DB.
2. Skip if device hasn't connected within ``device_inactivity_pause_days``.
3. Ask the connector to ``list_new`` — fresh items since last poll.
4. For each item:
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
- fetch the full content via the connector (transient, never persisted)
- run the triage LLM call → relevant | spam
- spam + auto_trash_spam → connector.archive
- relevant → INSERT scout_triage_queue row
5. Update scout.last_run_at.
Delivery flow on Electron WS reconnect:
- drain ``status='queued'`` rows for the user
- fetch metadata-only for each (subject + snippet)
- send a ``scout_proposal`` frame
- flip status to ``delivered`` on ack
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm
from app.db import async_session
from app.models import CloudScoutConfig, ScoutTriageQueue
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
from app.scouts.connectors.registry import get_connector
logger = logging.getLogger(__name__)
QUEUE_TTL_DAYS = 30
class ScoutEngine:
def __init__(self, session_factory=None) -> None:
self._session_factory = session_factory or async_session
async def trigger_scout(self, scout_id: uuid.UUID) -> None:
async with self._session_factory() as session:
scout = await session.get(CloudScoutConfig, str(scout_id))
if scout is None:
logger.warning("trigger_scout: no such scout id=%s", scout_id)
return
if not scout.enabled:
return
# Device-inactivity pause check is a simple heuristic on last_run_at —
# the device-online signal lives in the DeviceConnectionManager and is
# consulted at delivery time. For triage, we only check that the
# configured pause threshold isn't suppressing the run.
connector = get_connector(scout.provider)
try:
refs = await connector.list_new(scout)
except Exception:
logger.exception("scout %s: list_new failed", scout.id)
return
for ref in refs:
await self._process_item(session, scout, connector, ref)
scout.last_run_at = datetime.now(tz=timezone.utc)
await session.commit()
async def _process_item(
self,
session,
scout: CloudScoutConfig,
connector,
ref: ItemRef,
) -> None:
# Idempotency check
existing = await session.execute(
select(ScoutTriageQueue.id).where(
ScoutTriageQueue.scout_id == scout.id,
ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
)
)
if existing.first() is not None:
return
try:
content = await connector.fetch_content(scout, ref)
except Exception:
logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
return
try:
verdict = await self._triage_llm(scout, content)
except Exception:
logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
return
if verdict.verdict == "spam":
if scout.auto_trash_spam:
try:
await connector.archive(scout, ref)
except Exception:
logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
return
now = datetime.now(tz=timezone.utc)
row = ScoutTriageQueue(
id=str(uuid.uuid4()),
user_id=scout.user_id,
scout_id=scout.id,
source_type=connector.source_type,
source_msg_ref=ref.source_msg_ref,
triage_verdict=verdict.verdict,
triage_reason=verdict.reason,
status="queued",
triaged_at=now,
expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
)
session.add(row)
try:
# Use a savepoint so an IntegrityError on race doesn't poison the
# outer session — works on both PostgreSQL (SAVEPOINT) and SQLite.
async with session.begin_nested():
await session.flush()
except IntegrityError:
# Race: another worker inserted between our SELECT and INSERT.
# The unique constraint did its job; safe to ignore.
logger.debug(
"scout %s: idempotent skip for %s (race on unique constraint)",
scout.id,
ref.source_msg_ref,
)
async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
"""Drain status='queued' rows for user, send scout_proposal WS frames, flip to 'delivered'."""
from app.scouts.connectors.base import ItemRef # noqa: PLC0415
async with self._session_factory() as session:
rows = (await session.execute(
select(ScoutTriageQueue).where(
ScoutTriageQueue.user_id == str(user_id),
ScoutTriageQueue.status == "queued",
)
)).scalars().all()
for row in rows:
try:
connector = get_connector(row.source_type)
except KeyError:
logger.warning("deliver_pending: no connector for %s", row.source_type)
continue
scout = await session.get(CloudScoutConfig, row.scout_id)
if scout is None:
continue
try:
meta = await connector.fetch_metadata(scout, ItemRef(source_msg_ref=row.source_msg_ref))
except Exception:
logger.exception("deliver_pending: fetch_metadata failed")
continue
payload = {
"type": "scout_proposal",
"proposal": {
"id": row.id,
"scout_id": row.scout_id,
"source_type": row.source_type,
"source_msg_ref": row.source_msg_ref,
"raw_subject": meta.subject,
"raw_snippet": meta.snippet,
"category": "unprocessed",
"payload": None,
},
}
await ws.send_json(payload)
row.status = "delivered"
row.delivered_at = datetime.now(tz=timezone.utc)
await session.commit()
async def ack_proposal(self, proposal_id: str) -> None:
"""Flip a delivered proposal to acked. Idempotent — no-op if already acked."""
async with self._session_factory() as session:
row = await session.get(ScoutTriageQueue, proposal_id)
if row is None:
return
row.status = "acked"
row.acked_at = datetime.now(tz=timezone.utc)
await session.commit()
async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
"""Call the scout-triage-system Langfuse prompt to classify an item as relevant or spam.
Uses gpt-4o-mini with JSON mode. Wraps the LLM call in a Langfuse generation
observation when Langfuse is configured.
"""
import json # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
_TRIAGE_FALLBACK = (
"You are a triage classifier for an executive-assistant scout that watches a "
"{source_type} feed.\n"
'The scout\'s purpose is: "{scout_purpose}".\n\n'
"Given one item, decide whether it is RELEVANT (worth surfacing to the user as a "
"potential task / event / note / project) or SPAM (advertising, mass marketing, "
"phishing, bulk notifications with no actionable content).\n\n"
"Item:\n"
" - Subject: {item_subject}\n"
" - From: {item_sender}\n"
" - Body (truncated): {item_body_truncated_2k}\n\n"
'Return JSON only, matching this schema:\n'
' {{"verdict": "relevant" | "spam", "reason": <short string>, "confidence": <0..1>}}\n\n'
"Be conservative on \"spam\" — if a message could plausibly be a personal/work "
"email, mark it relevant."
)
template, prompt_obj = get_prompt_or_fallback("scout-triage-system", _TRIAGE_FALLBACK)
body_trunc = (content.body_text or "")[:2000]
variables = dict(
source_type=scout.provider,
scout_purpose=scout.prompt_template or "",
item_subject=content.metadata.subject or "",
item_sender=content.metadata.sender or "",
item_body_truncated_2k=body_trunc,
)
if prompt_obj is not None:
try:
system_text = prompt_obj.compile(**variables)
if isinstance(system_text, list):
system_text = "\n".join(
m.get("content", "") for m in system_text if isinstance(m, dict)
)
except Exception as exc:
logger.warning("scout triage: compile failed: %s", exc)
system_text = template.replace("{{source_type}}", variables["source_type"]) \
.replace("{{scout_purpose}}", variables["scout_purpose"]) \
.replace("{{item_subject}}", variables["item_subject"]) \
.replace("{{item_sender}}", variables["item_sender"]) \
.replace("{{item_body_truncated_2k}}", variables["item_body_truncated_2k"])
else:
system_text = template.format(**variables)
llm = get_llm(model="gpt-4o-mini", temperature=0)
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Classify this item."),
]
lf = get_langfuse()
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="scout-triage",
model="gpt-4o-mini",
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm_json.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm_json.ainvoke(messages)
data = json.loads(response.content)
return TriageVerdict(**data)

View File

@@ -39,3 +39,5 @@ lxml>=5.0.0
PyYAML>=6.0.0 PyYAML>=6.0.0
apscheduler>=3.10.0 apscheduler>=3.10.0
ruff>=0.8.0 ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1

View File

@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select
from app.config.settings import settings from app.config.settings import settings
from app.db import Base, get_session from app.db import Base, get_session
from app.main import app from app.main import app
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── Convenience aliases and per-tier user fixtures ────────────────────
@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
"""Alias for db_session — used by folder quota tests."""
return db_session
@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
"""Return the seeded free-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["free"])
)
return result.scalar_one()
@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
"""Return the seeded power-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["power"])
)
return result.scalar_one()
@pytest.fixture
def auth_headers_free() -> dict[str, str]:
"""Authorization header for the seeded free-tier user."""
return auth_header("free")
# ── CLI options ─────────────────────────────────────────────────────── # ── CLI options ───────────────────────────────────────────────────────
def pytest_addoption(parser): def pytest_addoption(parser):

View File

@@ -35,7 +35,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import yaml import yaml
from app.core.agent_runner import ( from app.core.scout_runner import (
_format_metadata, _format_metadata,
_format_projects, _format_projects,
_get_extraction_rules, _get_extraction_rules,
@@ -44,7 +44,7 @@ from app.core.agent_runner import (
) )
from app.core.device_manager import DeviceConnectionManager from app.core.device_manager import DeviceConnectionManager
from app.core.langfuse_client import get_langfuse from app.core.langfuse_client import get_langfuse
from app.models import AgentRunLog, LocalAgentConfig from app.models import ScoutRunLog, LocalScoutConfig
from tests.conftest import TEST_USER_IDS from tests.conftest import TEST_USER_IDS
# ── Constants ───────────────────────────────────────────────────────────── # ── Constants ─────────────────────────────────────────────────────────────
@@ -127,8 +127,8 @@ def _make_config(
agent_config: dict | None = None, agent_config: dict | None = None,
directory: str = "/emails", directory: str = "/emails",
device_id: str = "dev-001", device_id: str = "dev-001",
) -> LocalAgentConfig: ) -> LocalScoutConfig:
return LocalAgentConfig( return LocalScoutConfig(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
user_id=_USER_ID, user_id=_USER_ID,
device_id=device_id, device_id=device_id,
@@ -136,7 +136,7 @@ def _make_config(
directory_paths=[directory], directory_paths=[directory],
data_types=["tasks", "notes", "timelines"], data_types=["tasks", "notes", "timelines"],
prompt_template="", prompt_template="",
agent_config=agent_config or _AGENT_CONFIG, scout_config=agent_config or _AGENT_CONFIG,
file_extensions=[".html", ".eml"], file_extensions=[".html", ".eml"],
schedule_cron="0 */6 * * *", schedule_cron="0 */6 * * *",
enabled=True, enabled=True,
@@ -144,11 +144,11 @@ def _make_config(
) )
def _make_run_log(agent_id: str) -> AgentRunLog: def _make_run_log(agent_id: str) -> ScoutRunLog:
return AgentRunLog( return ScoutRunLog(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
agent_id=agent_id, scout_id=agent_id,
agent_type="local", scout_type="local",
user_id=_USER_ID, user_id=_USER_ID,
status="running", status="running",
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
@@ -271,7 +271,7 @@ async def test_2_9_device_offline():
run_log = _make_run_log(config.id) run_log = _make_run_log(config.id)
mgr = _make_manager(online=False) mgr = _make_manager(online=False)
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin: with patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr) await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args _, kwargs = mock_fin.call_args
@@ -295,8 +295,8 @@ async def test_2_10_empty_file():
projects=[_PROJECTS["alpha"]], projects=[_PROJECTS["alpha"]],
) )
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \ with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin: patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr) await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args _, kwargs = mock_fin.call_args
@@ -326,9 +326,9 @@ async def test_2_8_items_created_count():
_tool_calls_out.extend(["create_task", "create_note", "update_task"]) _tool_calls_out.extend(["create_task", "create_note", "update_task"])
return "Done." return "Done."
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \ with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \ patch("app.core.scout_runner._run_agent_with_tools", side_effect=mock_run_agent), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin: patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr) await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args _, kwargs = mock_fin.call_args
@@ -377,8 +377,8 @@ async def test_eval_runner(runner_case, pytestconfig):
) if lf else nullcontext() ) if lf else nullcontext()
with obs_ctx as obs: with obs_ctx as obs:
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \ with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin: patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr) await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args _, kwargs = mock_fin.call_args

View File

@@ -0,0 +1,52 @@
import pytest
from app.schemas.contextual import ContextualScope, render_scope_block
def test_render_project_scope():
scope = ContextualScope(
page="project",
entity_type="project",
entity_id="p1",
entity_name="Acme Q3 launch",
counts={"tasks": 12, "notes": 4, "milestones": 3},
)
block = render_scope_block(scope)
assert "Acme Q3 launch" in block
assert "12 tasks" in block
assert "4 notes" in block
assert "3 milestones" in block
assert "p1" not in block
def test_render_list_scope_no_entity():
scope = ContextualScope(page="tasks", entity_type=None)
block = render_scope_block(scope)
assert "tasks" in block.lower()
assert "None" not in block
def test_render_note_scope_includes_char_count():
scope = ContextualScope(
page="note",
entity_type="note",
entity_id="n1",
entity_name="Meeting 14 May",
project_id="p1",
char_count=4280,
)
block = render_scope_block(scope)
assert "Meeting 14 May" in block
assert "4280" in block or "4,280" in block
def test_parses_camelcase_payload_from_renderer():
payload = {
"page": "project",
"entityType": "project",
"entityId": "p1",
"entityName": "Acme",
"counts": {"tasks": 5, "notes": 1, "milestones": 2},
}
scope = ContextualScope.model_validate(payload)
assert scope.entity_id == "p1"
assert scope.entity_name == "Acme"

View File

@@ -0,0 +1,44 @@
"""Tests for contextual WS frame handlers.
These tests only exercise the new handler functions in device_ws.py and do
not depend on litellm or the full deep_agent import chain. They monkeypatch
run_contextual_stream so no LLM call is made.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
@pytest.mark.asyncio
async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch):
"""_handle_contextual_scope_update must:
- call append_system_message on the session buffer
- send a contextual_scope_ack back on the socket
- make no LLM call
"""
from app.api.routes import device_ws
ws = AsyncMock()
buffer = MagicMock()
buffer.append_system_message = MagicMock()
payload = {
"type": "contextual_scope_update",
"session_id": "s1",
"scope": {
"page": "project",
"entityType": "project",
"entityId": "p1",
"entityName": "Acme",
"counts": {"tasks": 1, "notes": 0, "milestones": 0},
},
}
monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer)
await device_ws._handle_contextual_scope_update(ws, "user1", payload)
ws.send_text.assert_awaited_once()
import json
sent = json.loads(ws.send_text.await_args.args[0])
assert sent["type"] == "contextual_scope_ack"
assert sent["session_id"] == "s1"
buffer.append_system_message.assert_called_once()

View File

@@ -12,11 +12,8 @@ from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import ( from app.core.deep_agent import (
_build_system_prompt, _build_system_prompt,
_datetime_context_injection, _datetime_context_injection,
_infer_floating_domain,
_normalize_tagged_list_lines, _normalize_tagged_list_lines,
_request_context_block, _request_context_block,
run_floating,
run_floating_stream,
run_home, run_home,
) )
@@ -75,57 +72,6 @@ async def test_run_home_uses_mocked_tool_result():
assert "Mock Task" in out assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
# _run_single_agent_stream uses ainvoke (not astream); the final token is
# the second LLM response which echoes the tool result.
token_events = [e for e in events if e[0] == "token"]
assert token_events, "Expected at least one token event"
combined = "".join(str(e[1]) for e in token_events)
assert "Mock Task" in combined
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines(): def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = ( raw = (
"Certo!\n\n" "Certo!\n\n"
@@ -162,139 +108,6 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
assert "<timeline>[tl-future]</timeline>" not in out assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events
# ── _datetime_context_injection ──────────────────────────────────────────────── # ── _datetime_context_injection ────────────────────────────────────────────────
def _fp(tz: str, now_iso: str) -> dict: def _fp(tz: str, now_iso: str) -> dict:

View File

@@ -22,7 +22,7 @@ import pytest
from app.core.device_manager import DeviceConnectionManager from app.core.device_manager import DeviceConnectionManager
from app.db import get_session from app.db import get_session
from app.main import app from app.main import app
from app.models import AgentRunLog from app.models import ScoutRunLog
from tests.conftest import TEST_USER_IDS, make_jwt from tests.conftest import TEST_USER_IDS, make_jwt
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -33,9 +33,9 @@ _FREE_UID = TEST_USER_IDS["free"]
_PRO_UID = TEST_USER_IDS["pro"] _PRO_UID = TEST_USER_IDS["pro"]
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str: def _device_hello(device_id: str = "dev-001", scout_ids: list[str] | None = None) -> str:
return json.dumps( return json.dumps(
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []} {"type": "device_hello", "device_id": device_id, "scout_ids": scout_ids or []}
) )
@@ -262,10 +262,10 @@ async def test_mark_runs_disconnected_updates_db(db_session):
user_id = TEST_USER_IDS["free"] user_id = TEST_USER_IDS["free"]
run_log = AgentRunLog( run_log = ScoutRunLog(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
agent_id=str(uuid.uuid4()), scout_id=str(uuid.uuid4()),
agent_type="local", scout_type="local",
user_id=user_id, user_id=user_id,
status="running", status="running",
started_at=datetime.now(timezone.utc), started_at=datetime.now(timezone.utc),
@@ -280,7 +280,7 @@ async def test_mark_runs_disconnected_updates_db(db_session):
# Verify through the same session factory. # Verify through the same session factory.
async with _TestSessionLocal() as s: async with _TestSessionLocal() as s:
result = await s.execute( result = await s.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log.id) select(ScoutRunLog).where(ScoutRunLog.id == run_log.id)
) )
updated = result.scalar_one_or_none() updated = result.scalar_one_or_none()

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.folder_agent import (
read_project_folder_file,
search_project_folder_file,
)
pytestmark = pytest.mark.asyncio
async def test_happy_path():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
assert "file body" in out
assert "kind=text" in out
async def test_traversal_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
assert out == "Access denied"
async def test_absolute_path_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
assert out == "Access denied"
async def test_missing_file():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
assert "not found" in out.lower()
async def test_pagination_signals_more_available():
# Electron returned the first slice, totalSize larger than slice length.
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "big.txt",
"offset": 0,
"length": 11,
})
assert "first chunk" in out
assert "More content available" in out
assert "offset=11" in out
async def test_pdf_extracted_then_sliced(monkeypatch):
from app.agents import folder_agent
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "doc.pdf",
"offset": 0,
"length": 8,
})
assert "kind=pdf" in out
assert "ABC ABC " in out
assert "More content available" in out
async def test_image_returns_placeholder():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
assert "image" in out.lower()
async def test_search_finds_match_with_context():
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "log.txt",
"query": "needle",
"context_lines": 1,
})
assert "needle" in out
assert "matches=1" in out
# Context lines included
assert "beta" in out
assert "gamma" in out
async def test_search_no_match():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "x.txt",
"query": "zzz",
})
assert "No matches" in out
async def test_search_rejects_traversal():
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "../etc/passwd",
"query": "root",
})
assert out == "Access denied"
async def test_search_image_rejected():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "logo.png",
"query": "anything",
})
assert "Cannot search" in out

View File

@@ -0,0 +1,83 @@
"""Folder indexer LLM helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
pytestmark = pytest.mark.asyncio
async def test_summarize_text_returns_summary_and_tokens():
mock_resp = AsyncMock()
mock_resp.content = "Kickoff notes covering scope and deadlines."
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
assert isinstance(result, IndexResult)
assert result.summary == "Kickoff notes covering scope and deadlines."
assert result.tokens_used == 338
async def test_summarize_text_truncates_summary_at_500_chars():
mock_resp = AsyncMock()
mock_resp.content = "x" * 1000
mock_resp.usage_metadata = {"total_tokens": 100}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="x", ext=".md", name="x.md")
assert len(result.summary) <= 500
async def test_summarize_image_uses_vision_content_blocks():
mock_resp = AsyncMock()
mock_resp.content = "Final logo on white background."
mock_resp.usage_metadata = {"total_tokens": 500}
captured = {}
async def fake_llm_vision(messages):
captured["messages"] = messages
return mock_resp
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
assert "Final logo" in result.summary
assert result.tokens_used == 500
# last message contains an image content block
last = captured["messages"][-1]
assert any(
isinstance(p, dict) and p.get("type") == "image_url"
for p in (last.content if isinstance(last.content, list) else [])
)
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
# pypdf.PdfReader returns text from pages
from app.core import folder_indexer
class FakePage:
def extract_text(self): return "PDF page content with project info."
class FakeReader:
pages = [FakePage(), FakePage()]
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
assert "Project info" in result.summary
assert result.tokens_used == 50
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
from app.core import folder_indexer
class FakePara:
def __init__(self, t): self.text = t
class FakeDoc:
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
assert result.summary == "Heading and body."

View File

@@ -0,0 +1,94 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
)
assert exc.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
# No raise
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
)
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db.add(MonthlyTokenUsage(
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
))
await db.commit()
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
)
assert exc.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
await check_folder_quota(
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
)
async def test_add_token_usage_atomic_increment(db, test_user_free):
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
ym = datetime.now(timezone.utc).strftime("%Y-%m")
row = (await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == test_user_free.id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)).scalar_one()
assert row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
result = await add_token_usage(
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
)
assert result.exhausted is True
assert result.tokens_used == 150_000
def test_quota_check_endpoint_rejects(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 500},
headers=auth_headers_free,
)
assert res.status_code == 402
body = res.json()
assert body["detail"]["reason"] == "max_files"
def test_quota_check_endpoint_passes(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 50},
headers=auth_headers_free,
)
assert res.status_code == 200
assert res.json() == {"ok": True}

View File

@@ -1,6 +1,6 @@
"""Tests for Local Agent V2 journey setup (Step 4). """Tests for Local Agent V2 journey setup (Step 4).
Covers the chatbot journey that produces a structured AgentConfig JSON Covers the chatbot journey that produces a structured ScoutConfig JSON
instead of a freeform prompt_template string. instead of a freeform prompt_template string.
Unit tests (no LLM) Unit tests (no LLM)
@@ -16,7 +16,7 @@ Eval test (real LLM + Langfuse scoring)
---------------------------------------- ----------------------------------------
4.1 Journey start explores directory → first reply contains a question 4.1 Journey start explores directory → first reply contains a question
Cases 4.24.5 (multi-turn conversations producing a full AgentConfig) are Cases 4.24.5 (multi-turn conversations producing a full ScoutConfig) are
non-deterministic and tested manually — results tracked in Langfuse. non-deterministic and tested manually — results tracked in Langfuse.
Run: Run:
@@ -37,7 +37,7 @@ from unittest.mock import patch
import pytest import pytest
import yaml import yaml
from app.api.routes.agent_setup import ( from app.api.routes.scout_setup import (
_CONFIG_END, _CONFIG_END,
_CONFIG_START, _CONFIG_START,
_MAX_TURNS, _MAX_TURNS,
@@ -48,7 +48,7 @@ from app.api.routes.agent_setup import (
) )
from app.core.langfuse_client import get_langfuse from app.core.langfuse_client import get_langfuse
from app.core.ws_context import clear_client_executor, set_client_executor from app.core.ws_context import clear_client_executor, set_client_executor
from app.schemas import AgentConfig from app.schemas import ScoutConfig
from tests.conftest import TEST_USER_IDS from tests.conftest import TEST_USER_IDS
# ── Constants ───────────────────────────────────────────────────────────── # ── Constants ─────────────────────────────────────────────────────────────
@@ -179,7 +179,7 @@ def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
def test_4_6a_extract_valid_json(): def test_4_6a_extract_valid_json():
"""_extract_agent_config: valid JSON between markers → returns serialised config.""" """_extract_agent_config: valid JSON between markers → returns serialised config."""
config = AgentConfig( config = ScoutConfig(
content_types=[], content_types=[],
global_rules=["No project = no entity"], global_rules=["No project = no entity"],
data_types=["tasks"], data_types=["tasks"],
@@ -187,7 +187,7 @@ def test_4_6a_extract_valid_json():
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing" text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
result = _extract_agent_config(text) result = _extract_agent_config(text)
assert result is not None assert result is not None
parsed = AgentConfig.model_validate_json(result) parsed = ScoutConfig.model_validate_json(result)
assert parsed.global_rules == ["No project = no entity"] assert parsed.global_rules == ["No project = no entity"]
@@ -230,7 +230,7 @@ async def test_4_6f_nudge_uses_new_markers():
# Return plain text — no markers — to trigger the nudge path. # Return plain text — no markers — to trigger the nudge path.
return "I still need more information from you." return "I still need more information from you."
from app.api.routes.agent_setup import JourneySession from app.api.routes.scout_setup import JourneySession
fake_session = JourneySession( fake_session = JourneySession(
session_id=session_id, session_id=session_id,
@@ -248,7 +248,7 @@ async def test_4_6f_nudge_uses_new_markers():
_sessions[session_id] = fake_session _sessions[session_id] = fake_session
try: try:
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm): with patch("app.api.routes.scout_setup._call_llm_with_tools", side_effect=_mock_llm):
await handle_journey_message(_USER_ID, { await handle_journey_message(_USER_ID, {
"session_id": session_id, "session_id": session_id,
"message": "one more message to trigger nudge", "message": "one more message to trigger nudge",

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
pytestmark = pytest.mark.asyncio
def test_format_folder_manifest_basic():
manifest = {
"folderPath": "D:\\Acme",
"lastScannedAt": "2h ago",
"files": [
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
],
}
out = format_folder_manifest(manifest)
assert "<linked_folder>" in out
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
assert "[text]" in out
assert "[image]" in out
def test_format_folder_manifest_truncates_past_budget():
files = [
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
for i in range(2000)
]
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
assert "more files omitted" in out
# Rough token check
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
def test_format_folder_manifest_null_returns_empty():
assert format_folder_manifest(None) == ""
assert format_folder_manifest({"files": []}) == ""
async def test_brief_multi_project_manifest_top_5_per_project():
fake_response = [
{
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
"lastScannedAt": "now",
"files": [
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
for i in range(10)
],
},
{
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
"lastScannedAt": "now",
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
},
]
with patch(
"app.core.deep_agent.execute_on_client",
new=AsyncMock(return_value={"projects": fake_response}),
):
from app.core.deep_agent import build_brief_multi_project_manifest
out = await build_brief_multi_project_manifest()
# Project 1 has 10 files, only top 5 by mtimeMs should appear
assert out.count("[p1]") <= 5
# Project 2 has 1 file, must appear
assert "[p2]" in out or "Beta" in out

View File

@@ -322,7 +322,7 @@ def test_home_request_calls_memory_middleware(client):
): ):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-mem", "agent_ids": [] "type": "device_hello", "device_id": "dev-mem", "scout_ids": []
})) }))
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "home_request", "type": "home_request",

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pytest import pytest
from app.core.output_formatter import StreamFormatter from app.core.output_formatter import StreamFormatter
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
async def _stream(*events: tuple[str, object]): async def _stream(*events: tuple[str, object]):
@@ -36,29 +36,6 @@ async def test_stream_formatter_text_stream() -> None:
assert isinstance(frames[-1], WsStreamEnd) assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2")
frames = await _collect(
formatter,
_stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary"
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stream_formatter_ignores_unknown_events() -> None: async def test_stream_formatter_ignores_unknown_events() -> None:
formatter = StreamFormatter(request_id="req-3") formatter = StreamFormatter(request_id="req-3")

View File

@@ -0,0 +1,85 @@
"""Tests for run_contextual_stream.
These tests monkeypatch _run_single_agent_stream (the actual internal runner)
rather than the plan's fictional _run_agent_loop, matching the real
deep_agent.py architecture.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.schemas.contextual import ContextualScope
@pytest.mark.asyncio
async def test_run_contextual_stream_includes_scope_block(monkeypatch):
"""run_contextual_stream must inject the scope block into the system prompt
and include get_page_details in the tool list while excluding note-edit tools."""
import app.core.deep_agent as deep_agent
captured = {}
async def fake_stream(
*,
user_id,
system_prompt,
message,
context,
agent_name="agent",
tools=None,
conversation_history=None,
**kwargs,
):
captured["sys"] = system_prompt
captured["tool_names"] = [getattr(t, "name", str(t)) for t in (tools or [])]
captured["agent_name"] = agent_name
# Async generator that yields nothing — still satisfies the protocol.
if False:
yield # pragma: no cover
monkeypatch.setattr(deep_agent, "_run_single_agent_stream", fake_stream)
scope = ContextualScope(
page="project",
entity_type="project",
entity_id="p1",
entity_name="Acme",
counts={"tasks": 1, "notes": 0, "milestones": 0},
)
context = {
"conversation_history": [],
"_debug": {"session_id": "s1"},
}
results = []
async for item in deep_agent.run_contextual_stream(
user_id="user1",
message="hi",
context=context,
scope=scope,
):
results.append(item)
assert "Acme" in captured["sys"], "scope block must appear in system prompt"
assert "Current view" in captured["sys"], "section header must be present"
names = captured["tool_names"]
assert "get_page_details" in names, "get_page_details tool must be included"
# Entity-create tools: at least one of these must be present.
assert any(n in names for n in ("create_task", "create_note", "update_task")), (
"at least one entity-create tool must be present"
)
assert "create_timeline" in names, "create_timeline tool must be included"
# Note edit tools must NOT be exposed.
assert "propose_note_edit" not in names, "propose_note_edit must be excluded"
# Legacy read tools must be excluded — they return shallow snapshots and
# cause the agent to under-answer (see trace 0b46841484ba7d024ed9f8d5ac8b1df0).
assert "list_projects" not in names, "list_projects must be excluded (legacy read)"
assert "get_project" not in names, "get_project must be excluded (legacy read)"
assert "list_tasks" not in names, "list_tasks must be excluded (legacy read)"
assert "get_task" not in names, "get_task must be excluded (legacy read)"
assert "list_notes" not in names, "list_notes must be excluded (legacy read)"
assert "get_note" not in names, "get_note must be excluded (legacy read)"

View File

@@ -4,12 +4,8 @@ import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.schemas import ( from app.schemas import (
WsDomain,
WsFrameType, WsFrameType,
WsHomeRequest, WsHomeRequest,
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
WsStreamEnd, WsStreamEnd,
WsStreamStart, WsStreamStart,
WsStreamText, WsStreamText,
@@ -22,11 +18,9 @@ from app.schemas import (
def test_v3_frame_types_exist(): def test_v3_frame_types_exist():
v3_types = [ v3_types = [
"home_request", "home_request",
"floating_request",
"stream_start", "stream_start",
"stream_text", "stream_text",
"stream_end", "stream_end",
"floating_domain",
"data_request", "data_request",
"data_response", "data_response",
"mutation", "mutation",
@@ -86,51 +80,6 @@ def test_home_request_requires_message():
WsHomeRequest.model_validate({"type": "home_request"}) WsHomeRequest.model_validate({"type": "home_request"})
# ── WsFloatingRequest ────────────────────────────────────────────────────
def test_floating_request_basic():
frame = WsFloatingRequest(
message="Summarise",
scope=WsFloatingScope(type="task", id="task-123"),
)
assert frame.type == WsFrameType.floating_request
assert frame.scope.type == "task"
assert frame.scope.id == "task-123"
def test_floating_request_scope_without_id():
frame = WsFloatingRequest(
message="Show all",
scope=WsFloatingScope(type="project"),
)
assert frame.scope.id is None
def test_floating_request_serializes():
frame = WsFloatingRequest(
message="Test",
scope=WsFloatingScope(type="note", id="n-1"),
)
data = frame.model_dump()
assert data["type"] == "floating_request"
assert data["scope"]["type"] == "note"
assert data["scope"]["id"] == "n-1"
def test_floating_request_invalid_scope_type():
with pytest.raises(ValidationError):
WsFloatingRequest(
message="X",
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
)
def test_floating_request_requires_scope():
with pytest.raises(ValidationError):
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
# ── WsStreamStart ───────────────────────────────────────────────────── # ── WsStreamStart ─────────────────────────────────────────────────────
@@ -189,51 +138,3 @@ def test_stream_end_deserializes():
assert frame.request_id == "r3" assert frame.request_id == "r3"
# ── WsFloatingDomain ─────────────────────────────────────────────────────
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain
assert frame.domain.type == "task"
def test_floating_domain_valid_domains():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_object_valid():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes():
d = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes():
raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -0,0 +1,48 @@
"""Tests for the connector registry."""
from __future__ import annotations
import pytest
from app.scouts.connectors.base import ItemRef
from app.scouts.connectors.registry import (
get_connector,
register_connector,
_reset_for_tests,
)
class _DummyConnector:
source_type = "dummy"
async def list_new(self, scout): return []
async def fetch_metadata(self, scout, ref): raise NotImplementedError
async def fetch_content(self, scout, ref): raise NotImplementedError
async def archive(self, scout, ref): raise NotImplementedError
async def setup_watch(self, scout): raise NotImplementedError
async def renew_watch(self, scout): raise NotImplementedError
@pytest.fixture(autouse=True)
def _clean_registry():
_reset_for_tests()
yield
_reset_for_tests()
def test_register_and_get():
c = _DummyConnector()
register_connector(c)
assert get_connector("dummy") is c
def test_unknown_source_raises():
with pytest.raises(KeyError):
get_connector("nope")
def test_double_register_replaces():
a = _DummyConnector()
b = _DummyConnector()
register_connector(a)
register_connector(b)
assert get_connector("dummy") is b

View File

@@ -0,0 +1,48 @@
"""Tests for the SourceConnector base protocol and shared types."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from app.scouts.connectors.base import (
ItemContent,
ItemMetadata,
ItemRef,
TriageVerdict,
)
def test_item_ref_round_trips_through_pydantic():
ref = ItemRef(source_msg_ref="abc123", received_at=datetime.now(tz=timezone.utc))
parsed = ItemRef.model_validate(ref.model_dump())
assert parsed.source_msg_ref == "abc123"
assert parsed.received_at == ref.received_at
def test_item_metadata_allows_all_optional():
meta = ItemMetadata()
assert meta.subject is None
assert meta.sender is None
assert meta.snippet is None
assert meta.received_at is None
def test_item_content_requires_metadata_and_body():
content = ItemContent(
metadata=ItemMetadata(subject="hi"),
body_text="hello world",
raw_headers={"X-Foo": "bar"},
)
assert content.metadata.subject == "hi"
assert content.body_text == "hello world"
assert content.raw_headers["X-Foo"] == "bar"
def test_triage_verdict_constraints():
v = TriageVerdict(verdict="relevant", reason="contains task language", confidence=0.92)
assert v.verdict == "relevant"
with pytest.raises(ValueError):
TriageVerdict(verdict="meh", reason="x", confidence=0.5) # bad enum value

View File

@@ -0,0 +1,84 @@
"""Tests for GmailConnector."""
from __future__ import annotations
import uuid
from unittest.mock import MagicMock, patch
import pytest
from app.models import CloudScoutConfig
from app.scouts.connectors.base import ItemRef
from app.scouts.connectors.gmail import GmailConnector
def _make_scout():
return CloudScoutConfig(
id=str(uuid.uuid4()),
user_id="00000000-0000-0000-0000-000000000003",
provider="gmail",
name="Inbox",
data_types=[],
prompt_template="",
oauth_token_encrypted="encrypted-blob",
schedule_cron="0 * * * *",
enabled=True,
auto_trash_spam=False,
device_inactivity_pause_days=14,
gmail_history_id="100",
)
@pytest.mark.asyncio
async def test_fetch_metadata_returns_subject_and_snippet():
scout = _make_scout()
conn = GmailConnector()
fake_message = {
"id": "msg-1",
"snippet": "preview text",
"payload": {"headers": [
{"name": "Subject", "value": "Hello"},
{"name": "From", "value": "alice@example.com"},
{"name": "Date", "value": "Wed, 14 May 2026 10:00:00 +0000"},
]},
}
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
mock_svc.return_value.users().messages().get().execute.return_value = fake_message
meta = await conn.fetch_metadata(scout, ItemRef(source_msg_ref="msg-1"))
assert meta.subject == "Hello"
assert meta.sender == "alice@example.com"
assert meta.snippet == "preview text"
@pytest.mark.asyncio
async def test_fetch_content_returns_body_text():
import base64
scout = _make_scout()
conn = GmailConnector()
body_data = base64.urlsafe_b64encode(b"hello world").decode()
fake_message = {
"id": "msg-1",
"snippet": "hello world",
"payload": {
"mimeType": "text/plain",
"headers": [
{"name": "Subject", "value": "S"},
{"name": "From", "value": "a@b"},
],
"body": {"data": body_data},
},
}
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
mock_svc.return_value.users().messages().get().execute.return_value = fake_message
content = await conn.fetch_content(scout, ItemRef(source_msg_ref="msg-1"))
assert content.body_text == "hello world"
assert content.metadata.subject == "S"
@pytest.mark.asyncio
async def test_archive_calls_trash():
scout = _make_scout()
conn = GmailConnector()
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
await conn.archive(scout, ItemRef(source_msg_ref="msg-1"))
mock_svc.return_value.users().messages().trash.assert_called()

270
tests/test_scout_engine.py Normal file
View File

@@ -0,0 +1,270 @@
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock
import pytest
from sqlalchemy import select
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
from app.scouts.connectors.registry import register_connector, _reset_for_tests
from app.scouts.engine import ScoutEngine
from tests.conftest import _TestSessionLocal
def _make_connector(items, content_for):
c = AsyncMock()
# source_type must match the scout.provider ("gmail") so get_connector()
# finds it when the engine calls get_connector(scout.provider).
c.source_type = "gmail"
c.list_new = AsyncMock(return_value=items)
c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
c.archive = AsyncMock()
return c
@pytest.fixture(autouse=True)
def _registry():
_reset_for_tests()
yield
_reset_for_tests()
@pytest.mark.asyncio
async def test_relevant_item_inserted_into_queue(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003" # power tier seeded in conftest
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
scout = CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
)
session.add(scout)
await session.commit()
refs = [ItemRef(source_msg_ref="msg-1")]
content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert len(rows) == 1
assert rows[0].source_msg_ref == "msg-1"
assert rows[0].triage_verdict == "relevant"
assert rows[0].status == "queued"
connector.archive.assert_not_awaited()
@pytest.mark.asyncio
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
scout = CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
)
session.add(scout)
await session.commit()
refs = [ItemRef(source_msg_ref="msg-spam")]
content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert rows == []
connector.archive.assert_awaited_once()
@pytest.mark.asyncio
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
scout = CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
)
session.add(scout)
await session.commit()
refs = [ItemRef(source_msg_ref="msg-2")]
content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert rows == []
connector.archive.assert_not_awaited()
@pytest.mark.asyncio
async def test_idempotent_replay(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
session.add(CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
))
await session.commit()
refs = [ItemRef(source_msg_ref="msg-3")]
content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert len(rows) == 1, "Replay must not create duplicate queue rows"
@pytest.mark.asyncio
async def test_triage_llm_parses_json_response(monkeypatch):
"""Real _triage_llm path: mock the LLM ainvoke, verify TriageVerdict parsed correctly."""
from unittest.mock import MagicMock # noqa: PLC0415
from app.models import CloudScoutConfig # noqa: PLC0415
scout = CloudScoutConfig(
id=str(uuid.uuid4()),
user_id="00000000-0000-0000-0000-000000000003",
provider="gmail",
name="test-scout",
data_types=[],
prompt_template="watch invoices and project updates",
schedule_cron="0 * * * *",
enabled=True,
auto_trash_spam=False,
device_inactivity_pause_days=14,
)
content = ItemContent(
metadata=ItemMetadata(subject="Invoice 42", sender="billing@acme.com"),
body_text="Payment of €1 200 is due on 2026-06-01. Please confirm receipt.",
)
# Build a fake LangChain response whose .content is valid JSON.
fake_response = MagicMock()
fake_response.content = '{"verdict": "relevant", "reason": "invoice due", "confidence": 0.92}'
fake_response.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
# Fake LLM: .bind() returns self (or another mock with ainvoke).
fake_llm = MagicMock()
fake_llm.bind.return_value = fake_llm
fake_llm.ainvoke = AsyncMock(return_value=fake_response)
# Patch get_llm inside app.scouts.engine so our fake is used.
monkeypatch.setattr("app.scouts.engine.get_llm", lambda **kwargs: fake_llm)
# Disable Langfuse for this test.
monkeypatch.setattr("app.scouts.engine.get_langfuse", lambda: None)
# Use fallback prompt (no Langfuse) — patch get_prompt_or_fallback to return fallback.
monkeypatch.setattr(
"app.scouts.engine.get_prompt_or_fallback",
lambda name, fallback: (fallback, None),
)
engine = ScoutEngine(session_factory=_TestSessionLocal)
verdict = await engine._triage_llm(scout, content)
assert verdict.verdict == "relevant"
assert verdict.reason == "invoice due"
assert abs(verdict.confidence - 0.92) < 1e-6
fake_llm.ainvoke.assert_awaited_once()
@pytest.mark.asyncio
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
now = datetime.now(tz=timezone.utc)
async with _TestSessionLocal() as session:
session.add(CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
))
for i in range(3):
session.add(ScoutTriageQueue(
id=str(uuid.uuid4()), user_id=user_id, scout_id=scout_id,
source_type="gmail", source_msg_ref=f"msg-{i}",
triage_verdict="relevant", status="queued",
triaged_at=now, expires_at=now + timedelta(days=30),
))
await session.commit()
connector = AsyncMock()
connector.source_type = "gmail"
connector.fetch_metadata = AsyncMock(side_effect=lambda scout, ref: ItemMetadata(
subject=f"sub-{ref.source_msg_ref}", snippet=f"snip-{ref.source_msg_ref}",
))
register_connector(connector)
sent = []
ws = AsyncMock()
ws.send_json = AsyncMock(side_effect=lambda payload: sent.append(payload))
engine = ScoutEngine(session_factory=_TestSessionLocal)
await engine.deliver_pending(uuid.UUID(user_id), ws)
assert len(sent) == 3
assert all(s["type"] == "scout_proposal" for s in sent)
subjects = {s["proposal"]["raw_subject"] for s in sent}
assert subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert all(r.status == "delivered" for r in rows)
assert all(r.delivered_at is not None for r in rows)

106
tests/test_scout_webhook.py Normal file
View File

@@ -0,0 +1,106 @@
"""Tests for the Gmail Pub/Sub webhook route.
Covers:
- Happy path: valid JWT + known user + enabled scout → 204, engine triggered.
- Rejection: invalid JWT → 401.
"""
from __future__ import annotations
import base64
import json
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from app.main import app
from app.models import CloudScoutConfig, User
from tests.conftest import _TestSessionLocal
def _pubsub_payload(email: str, history_id: str) -> dict:
"""Build a minimal Pub/Sub push envelope."""
inner = json.dumps({"emailAddress": email, "historyId": history_id}).encode()
return {
"message": {"data": base64.b64encode(inner).decode(), "messageId": "m1"},
"subscription": "projects/x/subscriptions/gmail-watch-sub",
}
@pytest.mark.asyncio
async def test_webhook_triggers_scout_for_matching_user():
"""204 returned and ScoutEngine.trigger_scout awaited for the matching scout."""
user_id = "00000000-0000-0000-0000-000000000003" # seeded 'power' user
scout_id = str(uuid.uuid4())
# Mutate the seeded user email so the webhook can resolve it,
# and add a cloud scout config for gmail.
async with _TestSessionLocal() as session:
user = await session.get(User, user_id)
user.email = "alice@example.com"
session.add(
CloudScoutConfig(
id=scout_id,
user_id=user_id,
provider="gmail",
name="Inbox",
data_types=[],
prompt_template="",
schedule_cron="0 * * * *",
enabled=True,
auto_trash_spam=False,
device_inactivity_pause_days=14,
)
)
await session.commit()
payload = _pubsub_payload("alice@example.com", "200")
with (
patch(
"app.api.routes.scout_webhooks._verify_pubsub_jwt",
return_value=True,
),
patch(
"app.api.routes.scout_webhooks.async_session",
_TestSessionLocal,
),
patch(
"app.scouts.engine.ScoutEngine.trigger_scout",
new=AsyncMock(),
) as mock_trigger,
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/scouts/webhooks/gmail",
json=payload,
headers={"Authorization": "Bearer fake-google-jwt"},
)
assert resp.status_code == 204
mock_trigger.assert_awaited_once_with(uuid.UUID(scout_id))
@pytest.mark.asyncio
async def test_webhook_rejects_unverified_jwt():
"""401 returned when JWT verification fails."""
payload = _pubsub_payload("alice@example.com", "200")
with patch(
"app.api.routes.scout_webhooks._verify_pubsub_jwt",
return_value=False,
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/scouts/webhooks/gmail",
json=payload,
headers={"Authorization": "Bearer bogus"},
)
assert resp.status_code == 401

View File

@@ -0,0 +1,196 @@
"""Tests for WS folder index_session handlers (Task 9).
Tests the three handler functions directly with a minimal fake WebSocket so
no real WS connection or LLM call is made.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from app.api.routes.device_ws import (
_handle_index_session_start,
_handle_index_file_batch,
_handle_index_session_cancel,
_index_sessions,
)
from app.billing.quota import add_token_usage
from app.core.folder_indexer import IndexResult
from app.models import MonthlyTokenUsage
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS
pytestmark = pytest.mark.asyncio
USER_ID = TEST_USER_IDS["free"]
POWER_USER_ID = TEST_USER_IDS["power"]
# ── Fake WebSocket ────────────────────────────────────────────────────
class _FakeWebSocket:
"""Minimal WebSocket stand-in that records send_text calls."""
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_text(self, text: str) -> None:
self.sent.append(json.loads(text))
def sent_types(self) -> list[str]:
return [f["type"] for f in self.sent]
# ── Helpers ───────────────────────────────────────────────────────────
def _make_session_id() -> str:
import uuid
return str(uuid.uuid4())
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
"""Return an AsyncMock that resolves to a fixed IndexResult."""
async def _fake(**kwargs) -> IndexResult:
return IndexResult(summary=summary, tokens_used=tokens)
return _fake
# ── Fixtures ──────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
"""Ensure _index_sessions is empty before and after each test."""
_index_sessions.clear()
yield
_index_sessions.clear()
# ── Tests ─────────────────────────────────────────────────────────────
async def test_index_session_happy_path(db_session):
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
# Register session.
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"projectId": "proj-1",
"totalFiles": 2,
})
# Verify session was registered.
assert session_id in _index_sessions
assert _index_sessions[session_id]["total"] == 2
assert _index_sessions[session_id]["processed"] == 0
# No response frames expected for session_start.
assert ws.sent == []
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
with patch(
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
# We patch the module-level function in folder_indexer instead:
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
# Wire db_session into the context manager.
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_async_session.return_value = mock_cm
await _handle_index_file_batch(ws, USER_ID, {
"sessionId": session_id,
"files": [
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
],
})
types = ws.sent_types()
# Expect 2 file results + 1 progress + 1 done(completed).
assert types.count(WsFrameType.index_file_result) == 2
assert types.count(WsFrameType.index_session_progress) == 1
assert types.count(WsFrameType.index_session_done) == 1
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "completed"
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
assert progress_frame["processed"] == 2
assert progress_frame["total"] == 2
# Verify session cleaned up.
assert session_id not in _index_sessions
async def test_index_session_cancel(db_session):
"""start then cancel → index_session_done(cancelled)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"totalFiles": 5,
})
assert session_id in _index_sessions
await _handle_index_session_cancel(ws, {"sessionId": session_id})
types = ws.sent_types()
assert WsFrameType.index_session_done in types
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "cancelled"
# Session should be cleaned up.
assert session_id not in _index_sessions
async def test_index_session_quota_exceeded(db_session):
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
# Pre-fill monthly token usage to the free-tier cap (100_000).
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db_session.add(MonthlyTokenUsage(
user_id=USER_ID,
year_month=ym,
feature="folder_index",
tokens_used=100_000, # free tier cap exactly
))
await db_session.commit()
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"totalFiles": 1,
})
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_async_session.return_value = mock_cm
await _handle_index_file_batch(ws, USER_ID, {
"sessionId": session_id,
"files": [
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
],
})
types = ws.sent_types()
# Should have 1 file result (success) then done(quota_exceeded).
assert WsFrameType.index_file_result in types
assert WsFrameType.index_session_done in types
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "quota_exceeded"
# Session should be cleaned up.
assert session_id not in _index_sessions

View File

@@ -1,6 +1,6 @@
"""Integration tests for the unified WebSocket handler (Step 5). """Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and floating_request frames, Tests the device WS endpoint with home_request frames,
verifying that the correct v3 frame sequence is returned. verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency. LLM calls are mocked to avoid network dependency.
@@ -34,7 +34,7 @@ def _override_db(db_session):
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]: def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames.""" """Receive frames until stream_end or max_frames."""
frames = [] frames = []
for _ in range(max_frames): for _ in range(max_frames):
raw = ws.receive_text() raw = ws.receive_text()
@@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context):
yield "token", "Hello" yield "token", "Hello"
async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary"
# ── tests ───────────────────────────────────────────────────────────────────── # ── tests ─────────────────────────────────────────────────────────────────────
def test_home_request_produces_stream_frames(client): def test_home_request_produces_stream_frames(client):
@@ -63,7 +58,7 @@ def test_home_request_produces_stream_frames(client):
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream): with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-1", "agent_ids": [] "type": "device_hello", "device_id": "dev-1", "scout_ids": []
})) }))
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "home_request", "type": "home_request",
@@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client):
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end) assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "floating_request",
"request_id": "p1",
"message": "Summarize this task",
"scope": {"type": "task", "id": "task-123"},
}))
frames = _recv_until_end(ws)
types = [f["type"] for f in frames]
assert WsFrameType.floating_domain in types
assert WsFrameType.stream_end in types
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1"
def test_home_request_request_id_propagated(client): def test_home_request_request_id_propagated(client):
"""request_id in home_request is echoed in all response frames.""" """request_id in home_request is echoed in all response frames."""
token = make_jwt("power", user_id=USER_ID) token = make_jwt("power", user_id=USER_ID)
@@ -117,7 +85,7 @@ def test_home_request_request_id_propagated(client):
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream): with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-3", "agent_ids": [] "type": "device_hello", "device_id": "dev-3", "scout_ids": []
})) }))
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "home_request", "type": "home_request",
@@ -138,7 +106,7 @@ def test_tool_result_dispatch_silent_on_unknown_id(client):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05): with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-4", "agent_ids": [] "type": "device_hello", "device_id": "dev-4", "scout_ids": []
})) }))
ws.send_text(json.dumps({ ws.send_text(json.dumps({
"type": "tool_result", "id": "no-such-id", "ok": True "type": "tool_result", "id": "no-such-id", "ok": True