Compare commits
51 Commits
c20c6d7853
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0833db239c | ||
|
|
11b31e5814 | ||
|
|
cb274c9728 | ||
|
|
d3497a1908 | ||
|
|
0c0299808c | ||
|
|
d1016fd65a | ||
|
|
c559754532 | ||
|
|
9f21d5ae8f | ||
|
|
699bba3a30 | ||
|
|
1364b9ba37 | ||
|
|
27df8c0a8d | ||
|
|
4933f8055c | ||
|
|
ac33ac1c0d | ||
|
|
fbd308d288 | ||
|
|
105cf52083 | ||
|
|
c2b27d4fb7 | ||
|
|
b92e72b685 | ||
|
|
1ccb0282fe | ||
|
|
1a20c11e86 | ||
|
|
70c19d3064 | ||
|
|
886730b47e | ||
|
|
052c7e3741 | ||
|
|
d63fd5f3b9 | ||
|
|
5e42b2abb1 | ||
|
|
2b71469e86 | ||
|
|
6188ae15b3 | ||
|
|
e1db7cdf06 | ||
|
|
c53f08229c | ||
|
|
3e2d80d5bb | ||
|
|
cc0e258e8c | ||
|
|
12e203e63d | ||
|
|
ffcd7390f0 | ||
|
|
91e880f9d4 | ||
|
|
7d47ca54be | ||
|
|
956fa88853 | ||
|
|
fb2f59ccea | ||
|
|
56dbb7f4cd | ||
|
|
506f517851 | ||
|
|
520c186991 | ||
|
|
582bf27deb | ||
|
|
2aeb453229 | ||
|
|
b7a4edac90 | ||
|
|
822b4cd8b1 | ||
|
|
ab24fc4c91 | ||
|
|
a98e99f7a2 | ||
|
|
a0ff285bcd | ||
|
|
177c1a87dd | ||
|
|
441a4ea05c | ||
|
|
a693a64bf5 | ||
|
|
67562b8092 | ||
|
|
6f4c68b359 |
@@ -56,6 +56,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
|
|||||||
# A small model (e.g. gpt-4o-mini) is sufficient.
|
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||||
# LLM_MODEL_BRIEF_AGENT=
|
# LLM_MODEL_BRIEF_AGENT=
|
||||||
|
|
||||||
|
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
||||||
|
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
||||||
|
# LLM_MODEL_TASK_BRIEF_AGENT=
|
||||||
|
|
||||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||||
LLM_MODEL_SETUP_AGENT=
|
LLM_MODEL_SETUP_AGENT=
|
||||||
|
|
||||||
|
|||||||
41
alembic/versions/007_rename_agents_to_scouts.py
Normal file
41
alembic/versions/007_rename_agents_to_scouts.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
"""Rename agents to scouts.
|
||||||
|
|
||||||
|
Revision ID: 007
|
||||||
|
Revises: d6e3f4a5b6c7
|
||||||
|
Create Date: 2026-05-15
|
||||||
|
|
||||||
|
Renames the entire agents subsystem identifiers to scouts.
|
||||||
|
Pre-1.0 — no data preservation concerns beyond ALTER TABLE rename.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "007"
|
||||||
|
down_revision: Union[str, None] = "d6e3f4a5b6c7"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Tables
|
||||||
|
op.rename_table("local_agent_configs", "local_scout_configs")
|
||||||
|
op.rename_table("cloud_agent_configs", "cloud_scout_configs")
|
||||||
|
op.rename_table("agent_run_logs", "scout_run_logs")
|
||||||
|
|
||||||
|
# Columns
|
||||||
|
op.alter_column("local_scout_configs", "agent_config", new_column_name="scout_config")
|
||||||
|
op.alter_column("scout_run_logs", "agent_id", new_column_name="scout_id")
|
||||||
|
op.alter_column("scout_run_logs", "agent_type", new_column_name="scout_type")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column("scout_run_logs", "scout_type", new_column_name="agent_type")
|
||||||
|
op.alter_column("scout_run_logs", "scout_id", new_column_name="agent_id")
|
||||||
|
op.alter_column("local_scout_configs", "scout_config", new_column_name="agent_config")
|
||||||
|
|
||||||
|
op.rename_table("scout_run_logs", "agent_run_logs")
|
||||||
|
op.rename_table("cloud_scout_configs", "cloud_agent_configs")
|
||||||
|
op.rename_table("local_scout_configs", "local_agent_configs")
|
||||||
59
alembic/versions/008_scout_triage_queue.py
Normal file
59
alembic/versions/008_scout_triage_queue.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Scout triage queue + cloud_scout_configs alterations.
|
||||||
|
|
||||||
|
Revision ID: 008
|
||||||
|
Revises: 007
|
||||||
|
Create Date: 2026-05-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "008"
|
||||||
|
down_revision: Union[str, None] = "007"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"scout_triage_queue",
|
||||||
|
sa.Column("id", sa.Uuid(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column("user_id", sa.Uuid(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||||
|
sa.Column("scout_id", sa.Uuid(as_uuid=False), sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("source_type", sa.String(50), nullable=False),
|
||||||
|
sa.Column("source_msg_ref", sa.String(255), nullable=False),
|
||||||
|
sa.Column("triage_verdict", sa.String(20), nullable=False),
|
||||||
|
sa.Column("triage_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
|
||||||
|
sa.Column("triaged_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||||
|
sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("acked_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_scout_triage_queue_user_status", "scout_triage_queue", ["user_id", "status"])
|
||||||
|
op.create_index(
|
||||||
|
"ix_scout_triage_queue_expires_active",
|
||||||
|
"scout_triage_queue",
|
||||||
|
["expires_at"],
|
||||||
|
postgresql_where=sa.text("status != 'acked'"),
|
||||||
|
)
|
||||||
|
|
||||||
|
op.add_column("cloud_scout_configs", sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")))
|
||||||
|
op.add_column("cloud_scout_configs", sa.Column("gmail_history_id", sa.String(64), nullable=True))
|
||||||
|
op.add_column("cloud_scout_configs", sa.Column("gmail_watch_expires_at", sa.DateTime(timezone=True), nullable=True))
|
||||||
|
op.add_column("cloud_scout_configs", sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("cloud_scout_configs", "device_inactivity_pause_days")
|
||||||
|
op.drop_column("cloud_scout_configs", "gmail_watch_expires_at")
|
||||||
|
op.drop_column("cloud_scout_configs", "gmail_history_id")
|
||||||
|
op.drop_column("cloud_scout_configs", "auto_trash_spam")
|
||||||
|
|
||||||
|
op.drop_index("ix_scout_triage_queue_expires_active", table_name="scout_triage_queue")
|
||||||
|
op.drop_index("ix_scout_triage_queue_user_status", table_name="scout_triage_queue")
|
||||||
|
op.drop_table("scout_triage_queue")
|
||||||
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Add token tracking columns for folder integration.
|
||||||
|
|
||||||
|
Revision ID: d6e3f4a5b6c7
|
||||||
|
Revises: 006
|
||||||
|
Create Date: 2026-05-11 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "d6e3f4a5b6c7"
|
||||||
|
down_revision: Union[str, None] = "006"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"agent_run_logs",
|
||||||
|
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"monthly_token_usage",
|
||||||
|
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("year_month", sa.String(7), nullable=False),
|
||||||
|
sa.Column("feature", sa.String(64), nullable=False),
|
||||||
|
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_monthly_token_usage_user_month",
|
||||||
|
"monthly_token_usage",
|
||||||
|
["user_id", "year_month"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
||||||
|
op.drop_table("monthly_token_usage")
|
||||||
|
op.drop_column("agent_run_logs", "tokens_used")
|
||||||
52
app/agents/client_agent.py
Normal file
52
app/agents/client_agent.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Client agent — read-only tools for the clients table."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_clients(search: str = "", limit: int = 20) -> str:
|
||||||
|
"""List clients, optionally filtered by a name/email substring search.
|
||||||
|
|
||||||
|
search: optional substring to match against client name or email.
|
||||||
|
limit: max rows to return (default 20).
|
||||||
|
"""
|
||||||
|
filters: dict[str, Any] = {"limit": limit}
|
||||||
|
if search:
|
||||||
|
filters["search"] = search
|
||||||
|
|
||||||
|
result = await execute_on_client(action="select", table="clients", filters=filters)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No clients found."
|
||||||
|
lines = [
|
||||||
|
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
||||||
|
f"company: {r.get('company', '')})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_client(id: str) -> str:
|
||||||
|
"""Get full details for one client by UUID.
|
||||||
|
|
||||||
|
id: the client's UUID.
|
||||||
|
"""
|
||||||
|
if not id:
|
||||||
|
return "Client id is required."
|
||||||
|
|
||||||
|
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
||||||
|
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
||||||
|
if not row:
|
||||||
|
return f"Client '{id}' not found."
|
||||||
|
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
||||||
|
|
||||||
|
|
||||||
|
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
||||||
168
app/agents/folder_agent.py
Normal file
168
app/agents/folder_agent.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Scoped file-read and search tools for the project folder feature."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
# Cap returned slice size to keep tool output under control.
|
||||||
|
_MAX_RETURN_CHARS = 50_000
|
||||||
|
_MAX_SEARCH_MATCHES = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _is_unsafe_path(rel: str) -> bool:
|
||||||
|
if not rel:
|
||||||
|
return True
|
||||||
|
norm = rel.replace("\\", "/")
|
||||||
|
if norm.startswith("/"):
|
||||||
|
return True
|
||||||
|
# Windows drive letter
|
||||||
|
if len(rel) >= 2 and rel[1] == ":":
|
||||||
|
return True
|
||||||
|
parts = norm.split("/")
|
||||||
|
return ".." in parts
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
||||||
|
"""Return the raw Electron tool_result dict for a file read."""
|
||||||
|
return await execute_on_client(
|
||||||
|
action="read_project_folder_file",
|
||||||
|
data={
|
||||||
|
"projectId": project_id,
|
||||||
|
"relativePath": relative_path,
|
||||||
|
"offset": offset,
|
||||||
|
"length": length,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode(result: dict) -> tuple[str, str, int]:
|
||||||
|
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
||||||
|
extracts text from base64. For images, returns a placeholder string.
|
||||||
|
For text, content is already a sliced utf-8 string.
|
||||||
|
"""
|
||||||
|
kind = result.get("kind", "text")
|
||||||
|
content = result.get("content", "") or ""
|
||||||
|
total = int(result.get("totalSize", 0) or 0)
|
||||||
|
if kind == "image":
|
||||||
|
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
||||||
|
if kind == "pdf":
|
||||||
|
return (_extract_pdf_text(content), kind, total)
|
||||||
|
if kind == "docx":
|
||||||
|
return (_extract_docx_text(content), kind, total)
|
||||||
|
return (content, kind, total)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_project_folder_file(
|
||||||
|
project_id: str,
|
||||||
|
relative_path: str,
|
||||||
|
offset: int = 0,
|
||||||
|
length: int = _MAX_RETURN_CHARS,
|
||||||
|
) -> str:
|
||||||
|
"""Read a slice of a file inside the project's linked folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: project ID.
|
||||||
|
relative_path: path relative to the linked folder root.
|
||||||
|
offset: char offset to start reading from (0 = beginning).
|
||||||
|
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
||||||
|
|
||||||
|
Returns text content slice with a header showing position. Header tells you
|
||||||
|
when more content is available; call again with the suggested next offset.
|
||||||
|
|
||||||
|
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
||||||
|
on the extracted text. For images returns a placeholder; navigate with the
|
||||||
|
manifest summary instead.
|
||||||
|
"""
|
||||||
|
if _is_unsafe_path(relative_path):
|
||||||
|
return "Access denied"
|
||||||
|
|
||||||
|
result = await _fetch_file(project_id, relative_path, offset, length)
|
||||||
|
text, kind, total_size = _decode(result)
|
||||||
|
|
||||||
|
if not text and kind in ("missing", "error"):
|
||||||
|
return f"File not found or unreadable: {relative_path}"
|
||||||
|
|
||||||
|
if kind in ("pdf", "docx"):
|
||||||
|
# Backend extracted full text — apply offset/length on chars.
|
||||||
|
sliced = text[offset:offset + length]
|
||||||
|
slice_end = min(offset + length, len(text))
|
||||||
|
header = (
|
||||||
|
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
||||||
|
f"totalChars={len(text)}]"
|
||||||
|
)
|
||||||
|
if slice_end < len(text):
|
||||||
|
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||||
|
return header + "\n" + sliced
|
||||||
|
|
||||||
|
if kind == "text":
|
||||||
|
slice_end = offset + len(text)
|
||||||
|
header = (
|
||||||
|
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
||||||
|
f"totalBytes={total_size}]"
|
||||||
|
)
|
||||||
|
if slice_end < total_size:
|
||||||
|
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||||
|
return header + "\n" + text
|
||||||
|
|
||||||
|
# image or unknown
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def search_project_folder_file(
|
||||||
|
project_id: str,
|
||||||
|
relative_path: str,
|
||||||
|
query: str,
|
||||||
|
context_lines: int = 3,
|
||||||
|
) -> str:
|
||||||
|
"""Search a project folder file for a query string (case-insensitive substring).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: project ID.
|
||||||
|
relative_path: path relative to the linked folder root.
|
||||||
|
query: text to search for.
|
||||||
|
context_lines: number of lines of context around each match (default 3).
|
||||||
|
|
||||||
|
Returns matching line ranges with surrounding context and 1-based line numbers.
|
||||||
|
Capped at 20 matches; if more exist the header shows the total.
|
||||||
|
|
||||||
|
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
||||||
|
Images and binary files are not searchable.
|
||||||
|
"""
|
||||||
|
if _is_unsafe_path(relative_path):
|
||||||
|
return "Access denied"
|
||||||
|
if not query:
|
||||||
|
return "Empty query."
|
||||||
|
|
||||||
|
# For text we still need full file; pass length=very large.
|
||||||
|
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
||||||
|
text, kind, _ = _decode(result)
|
||||||
|
|
||||||
|
if not text and kind in ("missing", "error"):
|
||||||
|
return f"File not found or unreadable: {relative_path}"
|
||||||
|
if kind == "image":
|
||||||
|
return "Cannot search inside images."
|
||||||
|
|
||||||
|
lines = text.splitlines()
|
||||||
|
q = query.lower()
|
||||||
|
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
||||||
|
if not matches:
|
||||||
|
return f"No matches for '{query}' in {relative_path}."
|
||||||
|
|
||||||
|
shown = matches[:_MAX_SEARCH_MATCHES]
|
||||||
|
snippets: list[str] = []
|
||||||
|
for i in shown:
|
||||||
|
start = max(0, i - context_lines)
|
||||||
|
end = min(len(lines), i + context_lines + 1)
|
||||||
|
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
||||||
|
snippets.append(block)
|
||||||
|
|
||||||
|
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
||||||
|
body = "\n---\n".join(snippets)
|
||||||
|
return header + "\n" + body
|
||||||
|
|
||||||
|
|
||||||
|
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.llm import embed
|
from app.core.note_summarizer import generate_note_summary
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
_UUID_RE = re.compile(
|
||||||
@@ -19,9 +20,21 @@ def _is_uuid(value: str) -> bool:
|
|||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_summary(row: dict) -> str:
|
||||||
|
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
|
||||||
|
if summary:
|
||||||
|
return f" — {summary}"
|
||||||
|
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
|
||||||
|
return f" — {snippet}" if snippet else ""
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes with AI summaries, optionally scoped to a project by project_id.
|
||||||
|
|
||||||
|
Returns id, title, and ai_summary for each note so you can decide which
|
||||||
|
note to read in full with get_note before creating or updating.
|
||||||
|
"""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
@@ -31,7 +44,7 @@ async def list_notes(project_id: str = "") -> str:
|
|||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No notes found."
|
return "No notes found."
|
||||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
|
||||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,14 +79,10 @@ async def create_note(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
# Index the note content in the vector store.
|
note_id: str = row["id"]
|
||||||
vector = await embed(content)
|
# Generate summary asynchronously — fire-and-forget.
|
||||||
await execute_on_client(
|
asyncio.create_task(_refresh_summary(note_id, title, content))
|
||||||
action="vector_upsert",
|
return f"Note created: '{row['title']}' (id: {note_id})."
|
||||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -82,7 +91,8 @@ async def update_note(
|
|||||||
title: str = "",
|
title: str = "",
|
||||||
content: str = "",
|
content: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update an existing note. Only pass fields that should change.
|
"""Update an existing note directly (no approval required).
|
||||||
|
Use propose_note_edit instead when human review is needed.
|
||||||
note_id: UUID of the note (required)
|
note_id: UUID of the note (required)
|
||||||
If you need to preserve existing content, call get_note first.
|
If you need to preserve existing content, call get_note first.
|
||||||
"""
|
"""
|
||||||
@@ -97,17 +107,63 @@ async def update_note(
|
|||||||
data={"id": note_id, "updates": updates},
|
data={"id": note_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
# Re-index if content changed.
|
|
||||||
if content:
|
if content:
|
||||||
vector = await embed(content)
|
new_title = title or row.get("title", "")
|
||||||
await execute_on_client(
|
asyncio.create_task(_refresh_summary(note_id, new_title, content))
|
||||||
action="vector_upsert",
|
|
||||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def propose_note_edit(
|
||||||
|
note_id: str,
|
||||||
|
edit_type: str,
|
||||||
|
proposed_content: str,
|
||||||
|
reasoning: str = "",
|
||||||
|
anchor_before: str = "",
|
||||||
|
anchor_text: str = "",
|
||||||
|
agent_id: str = "",
|
||||||
|
run_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Propose an AI edit to an existing note, pending human approval.
|
||||||
|
|
||||||
|
Use this instead of update_note when review_required is true.
|
||||||
|
The user will see the proposal highlighted before it is merged.
|
||||||
|
|
||||||
|
note_id: UUID of the target note (required)
|
||||||
|
edit_type: 'append' | 'insert' | 'replace'
|
||||||
|
- append: adds proposed_content at the end of the note
|
||||||
|
- insert: inserts proposed_content immediately after anchor_before text
|
||||||
|
- replace: replaces the first occurrence of anchor_text with proposed_content
|
||||||
|
proposed_content: the new Markdown text to add or substitute (required)
|
||||||
|
reasoning: brief explanation shown to the user (recommended)
|
||||||
|
anchor_before: for 'insert' — the text snippet that precedes the insertion point
|
||||||
|
anchor_text: for 'replace' — the exact text to be replaced
|
||||||
|
agent_id: agent identifier (for traceability)
|
||||||
|
run_id: run identifier (for traceability)
|
||||||
|
"""
|
||||||
|
if edit_type not in ("append", "insert", "replace"):
|
||||||
|
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
|
||||||
|
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="propose_note_edit",
|
||||||
|
data={
|
||||||
|
"noteId": note_id,
|
||||||
|
"type": edit_type,
|
||||||
|
"proposedContent": proposed_content,
|
||||||
|
"reasoning": reasoning or None,
|
||||||
|
"anchorBefore": anchor_before or None,
|
||||||
|
"anchorText": anchor_text or None,
|
||||||
|
"agentId": agent_id or None,
|
||||||
|
"runId": run_id or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
edit_id = result.get("id", "?")
|
||||||
|
return (
|
||||||
|
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
|
||||||
|
f"Status: pending user approval."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_note(note_id: str) -> str:
|
async def delete_note(note_id: str) -> str:
|
||||||
"""Delete a note permanently by its UUID."""
|
"""Delete a note permanently by its UUID."""
|
||||||
@@ -115,11 +171,32 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
|
||||||
|
"""Generate and persist the AI summary for a note. Fire-and-forget."""
|
||||||
|
try:
|
||||||
|
summary = await generate_note_summary(title, content)
|
||||||
|
if summary:
|
||||||
|
await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="notes",
|
||||||
|
data={
|
||||||
|
"id": note_id,
|
||||||
|
"updates": {
|
||||||
|
"aiSummary": summary,
|
||||||
|
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # fire-and-forget; errors logged by generate_note_summary
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
NOTE_TOOLS: list[Any] = [
|
||||||
list_notes,
|
list_notes,
|
||||||
get_note,
|
get_note,
|
||||||
create_note,
|
create_note,
|
||||||
update_note,
|
update_note,
|
||||||
|
propose_note_edit,
|
||||||
delete_note,
|
delete_note,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
63
app/agents/relations_agent.py
Normal file
63
app/agents/relations_agent.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import async_session
|
||||||
|
|
||||||
|
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
||||||
|
# Each tool closure captures the user_id bound at factory time.
|
||||||
|
|
||||||
|
|
||||||
|
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
||||||
|
"""Return a query_relations tool bound to *user_id*."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def query_relations(
|
||||||
|
subject_label: str = "",
|
||||||
|
predicate: str = "",
|
||||||
|
object_label: str = "",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> str:
|
||||||
|
"""Query the relational memory graph for entity relationships.
|
||||||
|
|
||||||
|
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
||||||
|
All parameters are optional — omit to retrieve all relations up to limit.
|
||||||
|
|
||||||
|
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
||||||
|
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
||||||
|
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
||||||
|
limit: max rows to return (default 10).
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(
|
||||||
|
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
||||||
|
trace_id or "-", user_id, subject_label, predicate, object_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
rows = await memory.query_relations(
|
||||||
|
user_id=user_id,
|
||||||
|
subject=subject_label or None,
|
||||||
|
predicate=predicate or None,
|
||||||
|
object_=object_label or None,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return "No relational memory entries found for the given filters."
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
||||||
|
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return query_relations
|
||||||
@@ -1,232 +0,0 @@
|
|||||||
"""Agent routes.
|
|
||||||
|
|
||||||
Backend responsibilities are intentionally minimal:
|
|
||||||
GET /agents/catalog — static catalog for UI display
|
|
||||||
POST /agents/can-create — billing eligibility check
|
|
||||||
POST /agents/trigger — trigger a local agent run
|
|
||||||
|
|
||||||
Agent configuration is owned by the Electron app and is not persisted
|
|
||||||
in backend agent-config tables.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.billing.tier_manager import FEATURES
|
|
||||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
|
||||||
from app.core.device_manager import device_manager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
|
||||||
from app.schemas import (
|
|
||||||
AgentCatalogItem,
|
|
||||||
AgentCreationCheckRequest,
|
|
||||||
AgentCreationCheckResponse,
|
|
||||||
AgentRunLogResponse,
|
|
||||||
AgentTriggerRequest,
|
|
||||||
UserProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Datetime helpers ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _dt_ms(dt: datetime) -> int:
|
|
||||||
return int(dt.timestamp() * 1000)
|
|
||||||
|
|
||||||
|
|
||||||
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
|
|
||||||
def _to_data_types(values: list[str]) -> list[str]:
|
|
||||||
normalize = {
|
|
||||||
"task": "tasks", "tasks": "tasks",
|
|
||||||
"note": "notes", "notes": "notes",
|
|
||||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
|
||||||
"project": "projects", "projects": "projects",
|
|
||||||
}
|
|
||||||
seen: set[str] = set()
|
|
||||||
result: list[str] = []
|
|
||||||
for v in values:
|
|
||||||
mapped = normalize.get(v)
|
|
||||||
if mapped and mapped not in seen:
|
|
||||||
seen.add(mapped)
|
|
||||||
result.append(mapped)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|
||||||
return AgentRunLogResponse(
|
|
||||||
id=log.id,
|
|
||||||
agent_id=log.agent_id,
|
|
||||||
agent_type=log.agent_type, # type: ignore[arg-type]
|
|
||||||
status=log.status, # type: ignore[arg-type]
|
|
||||||
items_processed=log.items_processed,
|
|
||||||
items_created=log.items_created,
|
|
||||||
errors=log.errors or [],
|
|
||||||
started_at=_dt_ms(log.started_at),
|
|
||||||
completed_at=_dt_ms_opt(log.completed_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
|
||||||
if limit != -1 and current_count >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
|
||||||
)
|
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
async def _enforce_run_frequency(
|
|
||||||
tier: str,
|
|
||||||
user_id: str,
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
|
||||||
if limit == -1:
|
|
||||||
return # unlimited
|
|
||||||
|
|
||||||
today_start = datetime.now(timezone.utc).replace(
|
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
|
||||||
)
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.count(AgentRunLog.id)).where(
|
|
||||||
AgentRunLog.user_id == user_id,
|
|
||||||
AgentRunLog.started_at >= today_start,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
runs_today: int = result.scalar_one()
|
|
||||||
|
|
||||||
if runs_today >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/catalog", response_model=list[AgentCatalogItem])
|
|
||||||
async def get_agent_catalog(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[AgentCatalogItem]:
|
|
||||||
"""Return the static list of available agent types and their descriptions."""
|
|
||||||
return [
|
|
||||||
AgentCatalogItem(
|
|
||||||
type="local_directory",
|
|
||||||
name="Local Directory Monitor",
|
|
||||||
description="Watches local directories, extracts data from files using AI",
|
|
||||||
),
|
|
||||||
AgentCatalogItem(
|
|
||||||
type="gmail",
|
|
||||||
name="Gmail Connector",
|
|
||||||
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
|
||||||
),
|
|
||||||
AgentCatalogItem(
|
|
||||||
type="teams",
|
|
||||||
name="Microsoft Teams Connector",
|
|
||||||
description="Monitors Teams messages, extracts action items",
|
|
||||||
),
|
|
||||||
AgentCatalogItem(
|
|
||||||
type="outlook",
|
|
||||||
name="Outlook Connector",
|
|
||||||
description="Scans Outlook inbox, extracts tasks/notes",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
|
||||||
async def can_create_agent(
|
|
||||||
body: AgentCreationCheckRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> AgentCreationCheckResponse:
|
|
||||||
"""Check if the user can create one more agent based on billing tier.
|
|
||||||
|
|
||||||
Since configuration is client-owned, the Electron app sends its current
|
|
||||||
active agent count and the backend applies tier limits.
|
|
||||||
"""
|
|
||||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
|
||||||
allowed = limit == -1 or body.active_agents < limit
|
|
||||||
return AgentCreationCheckResponse(
|
|
||||||
allowed=allowed,
|
|
||||||
tier=current_user.tier,
|
|
||||||
active_agents=body.active_agents,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
|
||||||
body: AgentTriggerRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> AgentRunLogResponse:
|
|
||||||
"""Trigger a local agent run using client-provided configuration."""
|
|
||||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
|
||||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
|
||||||
|
|
||||||
last_run_dt = (
|
|
||||||
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
|
||||||
if body.last_run_at
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
config = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=current_user.id,
|
|
||||||
device_id=body.device_id,
|
|
||||||
name="Local Directory Monitor",
|
|
||||||
directory_paths=[body.directory],
|
|
||||||
data_types=_to_data_types(body.what_to_extract),
|
|
||||||
prompt_template=body.custom_agent_prompt or "",
|
|
||||||
agent_config=body.agent_config,
|
|
||||||
file_extensions=[],
|
|
||||||
schedule_cron=body.batch_interval,
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=last_run_dt,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
|
||||||
stable_agent_id = body.agent_id or config.id
|
|
||||||
|
|
||||||
if is_agent_running(stable_agent_id):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
|
||||||
)
|
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
|
||||||
agent_id=stable_agent_id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=current_user.id,
|
|
||||||
status="running",
|
|
||||||
)
|
|
||||||
db.add(run_log)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(run_log)
|
|
||||||
|
|
||||||
run_context = {
|
|
||||||
"type": "agent_batch",
|
|
||||||
"run_id": run_log.id,
|
|
||||||
"agent_id": stable_agent_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
asyncio.create_task(
|
|
||||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
|
||||||
)
|
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
|
||||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Request, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -96,3 +96,37 @@ async def list_invoices(
|
|||||||
"""
|
"""
|
||||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||||
return invoices
|
return invoices
|
||||||
|
|
||||||
|
|
||||||
|
# ── Quota check ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaCheckRequest(BaseModel):
|
||||||
|
feature: str
|
||||||
|
estimated_files: int
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/quota/check")
|
||||||
|
async def quota_check(
|
||||||
|
payload: QuotaCheckRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict:
|
||||||
|
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
||||||
|
if payload.feature != "folder_index":
|
||||||
|
raise HTTPException(status_code=400, detail="Unknown feature")
|
||||||
|
try:
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=current_user.id,
|
||||||
|
tier=current_user.tier,
|
||||||
|
estimated_files=payload.estimated_files,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
except QuotaExceeded as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={"reason": exc.reason, "message": str(exc)},
|
||||||
|
)
|
||||||
|
return {"ok": True}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ available during the WebSocket handshake).
|
|||||||
|
|
||||||
Protocol:
|
Protocol:
|
||||||
1. Client connects → JWT validated → connection accepted.
|
1. Client connects → JWT validated → connection accepted.
|
||||||
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
2. Client sends ``device_hello`` frame: ``{ type, device_id, scout_ids }``.
|
||||||
3. Backend registers the connection in ``DeviceConnectionManager``.
|
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||||
4. Session enters message dispatch loop + heartbeat.
|
4. Session enters message dispatch loop + heartbeat.
|
||||||
|
|
||||||
@@ -39,23 +39,31 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.scouts.engine import ScoutEngine
|
||||||
|
from app.core.scout_runner import trigger_pending_runs
|
||||||
|
from app.core.scout_session_buffer import session_buffer
|
||||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
|
||||||
|
from app.core.output_formatter import extract_canvas_block
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import ScoutRunLog
|
||||||
from app.schemas import WsFrameType, WsStreamEnd
|
from app.schemas import WsFrameType, WsStreamEnd
|
||||||
|
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
|
# ── v7 folder index session state ─────────────────────────────────────
|
||||||
|
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||||
|
_index_sessions: dict[str, dict] = {}
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
@@ -93,7 +101,7 @@ async def device_ws(websocket: WebSocket) -> None:
|
|||||||
if hello.get("type") != WsFrameType.device_hello:
|
if hello.get("type") != WsFrameType.device_hello:
|
||||||
raise ValueError("expected device_hello as first frame")
|
raise ValueError("expected device_hello as first frame")
|
||||||
device_id: str = hello["device_id"]
|
device_id: str = hello["device_id"]
|
||||||
agent_ids: list[str] = hello.get("agent_ids", [])
|
scout_ids: list[str] = hello.get("scout_ids", [])
|
||||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||||
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||||
await websocket.close(code=1008)
|
await websocket.close(code=1008)
|
||||||
@@ -102,15 +110,25 @@ async def device_ws(websocket: WebSocket) -> None:
|
|||||||
# ── 3. Register connection ────────────────────────────────────────
|
# ── 3. Register connection ────────────────────────────────────────
|
||||||
device_manager.register(user_id, device_id, websocket)
|
device_manager.register(user_id, device_id, websocket)
|
||||||
logger.info(
|
logger.info(
|
||||||
"device_ws: connected user=%s device=%s agents=%s",
|
"device_ws: connected user=%s device=%s scouts=%s",
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
agent_ids,
|
scout_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Trigger any overdue agent runs now that the device is connected.
|
# Trigger any overdue agent runs now that the device is connected.
|
||||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
# Drain any queued scout proposals and deliver to the client (non-blocking).
|
||||||
|
async def _deliver_pending_safe() -> None:
|
||||||
|
import uuid as _uuid # noqa: PLC0415
|
||||||
|
try:
|
||||||
|
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout deliver_pending failed for user %s", user_id)
|
||||||
|
|
||||||
|
asyncio.create_task(_deliver_pending_safe())
|
||||||
|
|
||||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
@@ -154,16 +172,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.floating_request:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_floating_request(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.brief_request:
|
elif frame_type == WsFrameType.brief_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_brief_request(websocket, user_id, frame)
|
_handle_brief_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.task_brief_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_task_brief_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_start:
|
elif frame_type == WsFrameType.journey_start:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_journey_start(websocket, user_id, frame)
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
@@ -174,6 +192,37 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_journey_message(websocket, user_id, frame)
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_session_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_index_session_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_file_batch:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_index_file_batch(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_session_cancel:
|
||||||
|
await _handle_index_session_cancel(websocket, frame)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.contextual_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_contextual_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.contextual_scope_update:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == "scout_proposal_ack":
|
||||||
|
proposal_id = frame.get("proposal_id")
|
||||||
|
if proposal_id:
|
||||||
|
try:
|
||||||
|
await ScoutEngine().ack_proposal(proposal_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout ack_proposal failed for %s", proposal_id)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -205,11 +254,13 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||||
logger.info(
|
logger.info(
|
||||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
||||||
user_id,
|
user_id,
|
||||||
request_id,
|
request_id,
|
||||||
session_id,
|
session_id,
|
||||||
|
project_id,
|
||||||
message[:200],
|
message[:200],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -234,7 +285,7 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(user_id, message, context)
|
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
@@ -264,26 +315,41 @@ async def _handle_home_request(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_floating_request(
|
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
|
||||||
|
"""Return a session-scoped buffer proxy for the given user+session.
|
||||||
|
|
||||||
|
Returns a _ContextualBufferProxy that exposes append_system_message().
|
||||||
|
Defined at module level so tests can monkeypatch it.
|
||||||
|
The channel kwarg is accepted for forward-compatibility.
|
||||||
|
"""
|
||||||
|
from app.core.scout_session_buffer import ContextualBufferProxy # noqa: PLC0415
|
||||||
|
return ContextualBufferProxy(session_buffer, user_id, session_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_contextual_request(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
frame: dict,
|
frame: dict,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
|
||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope_payload: dict = frame.get("scope", {})
|
||||||
logger.info(
|
logger.info(
|
||||||
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
|
||||||
user_id,
|
user_id,
|
||||||
request_id,
|
request_id,
|
||||||
session_id,
|
session_id,
|
||||||
json.dumps(scope, ensure_ascii=True)[:200],
|
|
||||||
message[:200],
|
message[:200],
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
scope = ContextualScope.model_validate(scope_payload)
|
||||||
|
|
||||||
|
# Enrich context with memory before the LLM call.
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(
|
||||||
@@ -295,9 +361,8 @@ async def _handle_floating_request(
|
|||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
"scope": scope,
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
"format_prefs": frame.get("format_prefs"),
|
"format_prefs": frame.get("format_prefs"),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -305,7 +370,12 @@ async def _handle_floating_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_floating_stream(user_id, message, context)
|
event_stream = run_contextual_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
message=message,
|
||||||
|
context=context,
|
||||||
|
scope=scope,
|
||||||
|
)
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
@@ -313,20 +383,20 @@ async def _handle_floating_request(
|
|||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"device_ws: floating_request failed user=%s req=%s: %s",
|
"device_ws: contextual_request failed user=%s req=%s: %s",
|
||||||
user_id, request_id, exc,
|
user_id, request_id, exc,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
clear_client_executor()
|
clear_client_executor()
|
||||||
|
|
||||||
# ── Memory: store episode after response ──────────────────────────
|
# Store episode so the contextual agent can recall prior turns.
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
user_id,
|
user_id,
|
||||||
request_id,
|
request_id,
|
||||||
session_id,
|
session_id,
|
||||||
@@ -334,6 +404,33 @@ async def _handle_floating_request(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_contextual_scope_update(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a contextual_scope_update frame.
|
||||||
|
|
||||||
|
Injects a synthetic system message into the session buffer so the next
|
||||||
|
agent turn knows the user navigated. No LLM call is made.
|
||||||
|
"""
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope = ContextualScope.model_validate(frame.get("scope", {}))
|
||||||
|
block = render_scope_block(scope)
|
||||||
|
buf = get_session_buffer(user_id, session_id, channel="contextual")
|
||||||
|
buf.append_system_message(
|
||||||
|
f"User navigated to a new view. {block} Treat this as the new active context."
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.contextual_scope_ack,
|
||||||
|
"session_id": session_id,
|
||||||
|
}))
|
||||||
|
logger.info(
|
||||||
|
"device_ws: contextual_scope_update user=%s session=%s page=%s",
|
||||||
|
user_id, session_id, scope.page,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _handle_brief_request(
|
async def _handle_brief_request(
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -415,6 +512,98 @@ async def _handle_brief_request(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_task_brief_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
||||||
|
|
||||||
|
Streams the briefing markdown back to the client.
|
||||||
|
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
||||||
|
"""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
session_id = frame.get("session_id") or str(uuid4())
|
||||||
|
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
||||||
|
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
||||||
|
user_id, request_id, task_id, project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
f"task brief: {task_id}",
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
"format_prefs": frame.get("format_prefs"),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
elif ws_frame.type == "stream_start":
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
# stream_end is emitted below with mutations — skip formatter's version
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
||||||
|
user_id, request_id, task_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# Extract canvas block then emit stream_end with optional mutations.
|
||||||
|
full_response = "".join(response_chunks)
|
||||||
|
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
||||||
|
|
||||||
|
mutations: list[dict] = []
|
||||||
|
if canvas_content:
|
||||||
|
mutations.append({
|
||||||
|
"type": "canvas_draft",
|
||||||
|
"content": canvas_content,
|
||||||
|
"kind": canvas_kind,
|
||||||
|
})
|
||||||
|
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
||||||
|
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -472,6 +661,174 @@ async def _handle_journey_message(
|
|||||||
clear_client_executor()
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_session_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||||
|
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||||
|
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
_index_sessions[session_id] = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"project_id": project_id,
|
||||||
|
"processed": 0,
|
||||||
|
"total": total,
|
||||||
|
"cancelled": False,
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||||
|
user_id, session_id, project_id, total,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_session_cancel(
|
||||||
|
websocket: WebSocket,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||||
|
session = _index_sessions.get(session_id)
|
||||||
|
if session:
|
||||||
|
session["cancelled"] = True
|
||||||
|
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "cancelled",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_file_batch(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Process a batch of files for an index session, streaming results back."""
|
||||||
|
# Lazy imports to avoid heavy load at module startup.
|
||||||
|
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||||
|
summarize_image,
|
||||||
|
summarize_pdf,
|
||||||
|
summarize_docx,
|
||||||
|
summarize_text,
|
||||||
|
)
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||||
|
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||||
|
files: list[dict] = frame.get("files", [])
|
||||||
|
|
||||||
|
session = _index_sessions.get(session_id)
|
||||||
|
if not session or session.get("cancelled"):
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||||
|
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
if session.get("cancelled"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
||||||
|
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
||||||
|
kind: str = file_info.get("kind") or "text"
|
||||||
|
content: str = file_info.get("content") or ""
|
||||||
|
ext: str = file_info.get("ext") or ""
|
||||||
|
mime: str = file_info.get("mime") or "application/octet-stream"
|
||||||
|
name: str = rel_path.split("/")[-1] or rel_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
if kind == "image":
|
||||||
|
res = await summarize_image(image_b64=content, mime=mime)
|
||||||
|
elif kind == "pdf":
|
||||||
|
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||||
|
elif kind == "docx":
|
||||||
|
res = await summarize_docx(docx_b64=content, name=name)
|
||||||
|
else:
|
||||||
|
res = await summarize_text(content=content, ext=ext, name=name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||||
|
session_id, rel_path, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_file_result,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"relPath": rel_path,
|
||||||
|
"summary": None,
|
||||||
|
"tokensUsed": 0,
|
||||||
|
"error": str(exc),
|
||||||
|
}))
|
||||||
|
session["processed"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Account for token usage and check cap.
|
||||||
|
usage = await add_token_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
feature="folder_index",
|
||||||
|
tokens=res.tokens_used,
|
||||||
|
db=db,
|
||||||
|
cap=cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_file_result,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"relPath": rel_path,
|
||||||
|
"summary": res.summary,
|
||||||
|
"tokensUsed": res.tokens_used,
|
||||||
|
}))
|
||||||
|
session["processed"] += 1
|
||||||
|
|
||||||
|
if usage.exhausted:
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "quota_exceeded",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||||
|
user_id, session_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# After processing the batch, emit progress.
|
||||||
|
processed = session["processed"]
|
||||||
|
total = session["total"]
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_progress,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"processed": processed,
|
||||||
|
"total": total,
|
||||||
|
}))
|
||||||
|
|
||||||
|
if processed >= total:
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "completed",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||||
|
user_id, session_id, processed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
@@ -484,14 +841,14 @@ async def _heartbeat_loop(websocket: WebSocket) -> None:
|
|||||||
# ── Disconnect cleanup ────────────────────────────────────────────────
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _mark_runs_disconnected(user_id: str) -> None:
|
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||||
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
"""Mark all in-progress ScoutRunLog rows as 'error' for this user."""
|
||||||
try:
|
try:
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
await db.execute(
|
await db.execute(
|
||||||
update(AgentRunLog)
|
update(ScoutRunLog)
|
||||||
.where(
|
.where(
|
||||||
AgentRunLog.user_id == user_id,
|
ScoutRunLog.user_id == user_id,
|
||||||
AgentRunLog.status == "running",
|
ScoutRunLog.status == "running",
|
||||||
)
|
)
|
||||||
.values(
|
.values(
|
||||||
status="error",
|
status="error",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
"""Chatbot Journey — WS-based guided conversation to build an ScoutConfig.
|
||||||
|
|
||||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
@@ -13,7 +13,7 @@ Journey flow:
|
|||||||
3. FE sends ``journey_message`` frames for each user reply.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
4. Server appends the user message, calls the LLM (which may read files
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
via tools), and sends back a ``journey_reply``.
|
via tools), and sends back a ``journey_reply``.
|
||||||
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
5. After 3-5 turns the LLM wraps up by emitting an ``ScoutConfig`` JSON
|
||||||
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
||||||
6. Server parses and validates the JSON with Pydantic, sends
|
6. Server parses and validates the JSON with Pydantic, sends
|
||||||
``journey_reply`` with ``done=True`` and the serialised config.
|
``journey_reply`` with ``done=True`` and the serialised config.
|
||||||
@@ -34,7 +34,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, Tool
|
|||||||
from app.agents.filesystem_agent import make_directory_tools
|
from app.agents.filesystem_agent import make_directory_tools
|
||||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.core.llm import get_agent_llm, model_for_agent
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.schemas import AgentConfig
|
from app.schemas import ScoutConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
# Sentinel strings used to delimit the LLM-produced ScoutConfig JSON.
|
||||||
_CONFIG_START = "AGENT_CONFIG_START"
|
_CONFIG_START = "AGENT_CONFIG_START"
|
||||||
_CONFIG_END = "AGENT_CONFIG_END"
|
_CONFIG_END = "AGENT_CONFIG_END"
|
||||||
|
|
||||||
@@ -92,7 +92,7 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
|||||||
_JOURNEY_SYSTEM_PROMPT = """\
|
_JOURNEY_SYSTEM_PROMPT = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand what files the user has in their directory and produce a
|
Your job is to understand what files the user has in their directory and produce a
|
||||||
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
structured ScoutConfig JSON that the extraction agent will use as its instruction set.
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
You have access to file-system tools to explore the user's directory:
|
||||||
- list_directory: see folder structure and file names
|
- list_directory: see folder structure and file names
|
||||||
@@ -122,7 +122,7 @@ Cover these topics based on what you discovered:
|
|||||||
4. Date extraction (e.g. "by Friday" → dueDate)
|
4. Date extraction (e.g. "by Friday" → dueDate)
|
||||||
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
||||||
|
|
||||||
### Step 4 — Produce the AgentConfig JSON
|
### Step 4 — Produce the ScoutConfig JSON
|
||||||
Once you are ≥ 90% confident, output the final config between these exact markers
|
Once you are ≥ 90% confident, output the final config between these exact markers
|
||||||
(each on its own line):
|
(each on its own line):
|
||||||
|
|
||||||
@@ -168,7 +168,7 @@ def _build_system_prompt(
|
|||||||
) -> tuple[str, Any]:
|
) -> tuple[str, Any]:
|
||||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
||||||
existing_section = (
|
existing_section = (
|
||||||
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
"\nThe user already has the following ScoutConfig — refine it based on their answers:\n"
|
||||||
f"```json\n{existing_config}\n```\n"
|
f"```json\n{existing_config}\n```\n"
|
||||||
if existing_config
|
if existing_config
|
||||||
else ""
|
else ""
|
||||||
@@ -189,11 +189,11 @@ def _build_system_prompt(
|
|||||||
return compiled, prompt_obj
|
return compiled, prompt_obj
|
||||||
|
|
||||||
|
|
||||||
# ── AgentConfig extraction ────────────────────────────────────────────────
|
# ── ScoutConfig extraction ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _extract_agent_config(text: str) -> str | None:
|
def _extract_agent_config(text: str) -> str | None:
|
||||||
"""Return validated AgentConfig JSON string from between markers, or None.
|
"""Return validated ScoutConfig JSON string from between markers, or None.
|
||||||
|
|
||||||
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
||||||
returning. Returns None if markers are absent or JSON is invalid.
|
returning. Returns None if markers are absent or JSON is invalid.
|
||||||
@@ -206,10 +206,10 @@ def _extract_agent_config(text: str) -> str | None:
|
|||||||
if not raw:
|
if not raw:
|
||||||
return None
|
return None
|
||||||
try:
|
try:
|
||||||
parsed = AgentConfig.model_validate_json(raw)
|
parsed = ScoutConfig.model_validate_json(raw)
|
||||||
return parsed.model_dump_json()
|
return parsed.model_dump_json()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
logger.warning("agent_setup: failed to parse ScoutConfig JSON: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -475,7 +475,7 @@ async def handle_journey_message(
|
|||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
nudge_content = (
|
nudge_content = (
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
f"ScoutConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
||||||
)
|
)
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
120
app/api/routes/scout_webhooks.py
Normal file
120
app/api/routes/scout_webhooks.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Gmail Pub/Sub push receiver.
|
||||||
|
|
||||||
|
Google Pub/Sub push subscriptions deliver Gmail watch notifications as POST
|
||||||
|
requests with a JSON envelope. The body payload contains a base64-encoded
|
||||||
|
JSON blob with ``emailAddress`` + ``historyId``. We resolve the user by
|
||||||
|
email, look up their cloud_scout_configs row for provider='gmail', and
|
||||||
|
hand off to ScoutEngine.trigger_scout.
|
||||||
|
|
||||||
|
Authentication: Pub/Sub push includes an OIDC JWT in the Authorization
|
||||||
|
header. We verify it against Google's public keys with the audience
|
||||||
|
configured in our Pub/Sub subscription.
|
||||||
|
|
||||||
|
Dev mode: when ``GMAIL_PUBSUB_AUDIENCE`` is empty, JWT verification is
|
||||||
|
skipped and a warning is logged. Production must set this env var.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, Request, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import CloudScoutConfig, User
|
||||||
|
from app.scouts.engine import ScoutEngine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(prefix="/scouts/webhooks", tags=["scout-webhooks"])
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_pubsub_jwt(token: str) -> bool:
|
||||||
|
"""Verify the Google Pub/Sub OIDC JWT.
|
||||||
|
|
||||||
|
Returns True when valid, False on any verification failure.
|
||||||
|
|
||||||
|
Dev skip: if ``settings.GMAIL_PUBSUB_AUDIENCE`` is empty, logs a
|
||||||
|
warning and returns True so local development works without a real
|
||||||
|
Pub/Sub subscription. Production must configure the audience.
|
||||||
|
"""
|
||||||
|
if not token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not settings.GMAIL_PUBSUB_AUDIENCE:
|
||||||
|
logger.warning(
|
||||||
|
"GMAIL_PUBSUB_AUDIENCE not set — skipping Pub/Sub JWT verification (dev mode only)"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
from google.auth.transport import requests as g_requests # noqa: PLC0415
|
||||||
|
from google.oauth2 import id_token # noqa: PLC0415
|
||||||
|
|
||||||
|
id_token.verify_oauth2_token(
|
||||||
|
token,
|
||||||
|
g_requests.Request(),
|
||||||
|
audience=settings.GMAIL_PUBSUB_AUDIENCE,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.warning("pubsub jwt verification failed", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/gmail", status_code=status.HTTP_204_NO_CONTENT)
|
||||||
|
async def gmail_pubsub(
|
||||||
|
request: Request,
|
||||||
|
authorization: str = Header(default=""),
|
||||||
|
) -> None:
|
||||||
|
"""Receive a Gmail Pub/Sub push notification.
|
||||||
|
|
||||||
|
Verifies the OIDC JWT, decodes the Pub/Sub envelope, resolves the user
|
||||||
|
by email, and triggers ScoutEngine.trigger_scout for each enabled Gmail
|
||||||
|
scout belonging to that user.
|
||||||
|
|
||||||
|
Returns 204 No Content on success (including benign no-ops like unknown
|
||||||
|
email or empty message data). Returns 401 on JWT verification failure.
|
||||||
|
"""
|
||||||
|
token = authorization.removeprefix("Bearer ").strip()
|
||||||
|
if not _verify_pubsub_jwt(token):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid Pub/Sub JWT")
|
||||||
|
|
||||||
|
body = await request.json()
|
||||||
|
msg = body.get("message") or {}
|
||||||
|
raw = msg.get("data")
|
||||||
|
if not raw:
|
||||||
|
return # ack without action — empty message data
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoded = json.loads(base64.b64decode(raw).decode())
|
||||||
|
except Exception:
|
||||||
|
logger.warning("pubsub payload decode failed")
|
||||||
|
return
|
||||||
|
|
||||||
|
email = decoded.get("emailAddress")
|
||||||
|
if not email:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
user_q = await session.execute(select(User).where(User.email == email))
|
||||||
|
user = user_q.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
logger.info("pubsub: no user for %s — ignoring", email)
|
||||||
|
return
|
||||||
|
scouts_q = await session.execute(
|
||||||
|
select(CloudScoutConfig).where(
|
||||||
|
CloudScoutConfig.user_id == user.id,
|
||||||
|
CloudScoutConfig.provider == "gmail",
|
||||||
|
CloudScoutConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
scouts = scouts_q.scalars().all()
|
||||||
|
|
||||||
|
engine = ScoutEngine()
|
||||||
|
for scout in scouts:
|
||||||
|
await engine.trigger_scout(uuid.UUID(str(scout.id)))
|
||||||
440
app/api/routes/scouts.py
Normal file
440
app/api/routes/scouts.py
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
"""Scout routes.
|
||||||
|
|
||||||
|
Backend responsibilities are intentionally minimal:
|
||||||
|
GET /scouts/catalog — static catalog for UI display
|
||||||
|
POST /scouts/can-create — billing eligibility check
|
||||||
|
POST /scouts/trigger — trigger a local scout run
|
||||||
|
|
||||||
|
Scout configuration is owned by the Electron app and is not persisted
|
||||||
|
in backend scout-config tables.
|
||||||
|
|
||||||
|
Gmail OAuth setup (scout-specific consent):
|
||||||
|
GET /scouts/oauth/gmail/authorize — returns consent-screen URL
|
||||||
|
GET /scouts/oauth/gmail/web-callback — bounces to deep link (excluded from schema)
|
||||||
|
POST /scouts/oauth/gmail/callback — exchanges code, stores encrypted token
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import urllib.parse
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.auth.oauth_providers import generate_pkce_pair
|
||||||
|
from app.billing.tier_manager import FEATURES
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.scout_runner import is_agent_running, run_local_agent
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.note_summarizer import generate_note_summary
|
||||||
|
from app.db import get_session
|
||||||
|
from app.integrations import encrypt_token
|
||||||
|
from app.models import CloudScoutConfig, ScoutRunLog, LocalScoutConfig
|
||||||
|
from app.schemas import (
|
||||||
|
ScoutCatalogItem,
|
||||||
|
ScoutCreationCheckRequest,
|
||||||
|
ScoutCreationCheckResponse,
|
||||||
|
ScoutRunLogResponse,
|
||||||
|
ScoutTriggerRequest,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/scouts", tags=["scouts"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Datetime helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
|
"task": "tasks", "tasks": "tasks",
|
||||||
|
"note": "notes", "notes": "notes",
|
||||||
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
|
"project": "projects", "projects": "projects",
|
||||||
|
}
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
for v in values:
|
||||||
|
mapped = normalize.get(v)
|
||||||
|
if mapped and mapped not in seen:
|
||||||
|
seen.add(mapped)
|
||||||
|
result.append(mapped)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _to_run_log_response(log: ScoutRunLog) -> ScoutRunLogResponse:
|
||||||
|
return ScoutRunLogResponse(
|
||||||
|
id=log.id,
|
||||||
|
agent_id=log.scout_id,
|
||||||
|
agent_type=log.scout_type, # type: ignore[arg-type]
|
||||||
|
status=log.status, # type: ignore[arg-type]
|
||||||
|
items_processed=log.items_processed,
|
||||||
|
items_created=log.items_created,
|
||||||
|
errors=log.errors or [],
|
||||||
|
started_at=_dt_ms(log.started_at),
|
||||||
|
completed_at=_dt_ms_opt(log.completed_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
async def _enforce_run_frequency(
|
||||||
|
tier: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return # unlimited
|
||||||
|
|
||||||
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.count(ScoutRunLog.id)).where(
|
||||||
|
ScoutRunLog.user_id == user_id,
|
||||||
|
ScoutRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog", response_model=list[ScoutCatalogItem])
|
||||||
|
async def get_agent_catalog(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> list[ScoutCatalogItem]:
|
||||||
|
"""Return the static list of available agent types and their descriptions."""
|
||||||
|
return [
|
||||||
|
ScoutCatalogItem(
|
||||||
|
type="local_directory",
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
description="Watches local directories, extracts data from files using AI",
|
||||||
|
),
|
||||||
|
ScoutCatalogItem(
|
||||||
|
type="gmail",
|
||||||
|
name="Gmail Connector",
|
||||||
|
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
),
|
||||||
|
ScoutCatalogItem(
|
||||||
|
type="teams",
|
||||||
|
name="Microsoft Teams Connector",
|
||||||
|
description="Monitors Teams messages, extracts action items",
|
||||||
|
),
|
||||||
|
ScoutCatalogItem(
|
||||||
|
type="outlook",
|
||||||
|
name="Outlook Connector",
|
||||||
|
description="Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/can-create", response_model=ScoutCreationCheckResponse)
|
||||||
|
async def can_create_agent(
|
||||||
|
body: ScoutCreationCheckRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> ScoutCreationCheckResponse:
|
||||||
|
"""Check if the user can create one more agent based on billing tier.
|
||||||
|
|
||||||
|
Since configuration is client-owned, the Electron app sends its current
|
||||||
|
active agent count and the backend applies tier limits.
|
||||||
|
"""
|
||||||
|
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||||
|
allowed = limit == -1 or body.active_agents < limit
|
||||||
|
return ScoutCreationCheckResponse(
|
||||||
|
allowed=allowed,
|
||||||
|
tier=current_user.tier,
|
||||||
|
active_agents=body.active_agents,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/trigger", response_model=ScoutRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
body: ScoutTriggerRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> ScoutRunLogResponse:
|
||||||
|
"""Trigger a local agent run using client-provided configuration."""
|
||||||
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
|
last_run_dt = (
|
||||||
|
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||||
|
if body.last_run_at
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
config = LocalScoutConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
device_id=body.device_id,
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
directory_paths=[body.directory],
|
||||||
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
|
prompt_template=body.custom_agent_prompt or "",
|
||||||
|
scout_config=body.agent_config,
|
||||||
|
file_extensions=[],
|
||||||
|
schedule_cron=body.batch_interval,
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=last_run_dt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
|
stable_agent_id = body.agent_id or config.id
|
||||||
|
|
||||||
|
if is_agent_running(stable_agent_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||||
|
)
|
||||||
|
|
||||||
|
run_log = ScoutRunLog(
|
||||||
|
scout_id=stable_agent_id,
|
||||||
|
scout_type="local",
|
||||||
|
user_id=current_user.id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
run_context = {
|
||||||
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_run_log_response(run_log)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Note summary endpoint ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class NoteSummarizeRequest(BaseModel):
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class NoteSummarizeResponse(BaseModel):
|
||||||
|
summary: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
|
||||||
|
async def summarize_note(
|
||||||
|
body: NoteSummarizeRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> NoteSummarizeResponse:
|
||||||
|
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
||||||
|
summary = await generate_note_summary(body.title, body.content)
|
||||||
|
return NoteSummarizeResponse(summary=summary)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Gmail OAuth setup (scout-specific) ───────────────────────────────────────
|
||||||
|
|
||||||
|
# Scopes required for Gmail scout connectivity.
|
||||||
|
_GMAIL_SCOUT_SCOPES = [
|
||||||
|
"openid",
|
||||||
|
"email",
|
||||||
|
"https://www.googleapis.com/auth/gmail.readonly",
|
||||||
|
"https://www.googleapis.com/auth/gmail.modify",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Google OAuth endpoints.
|
||||||
|
_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||||
|
|
||||||
|
# In-memory pending OAuth states for scout Gmail consent:
|
||||||
|
# state → (code_verifier, scout_id, user_id, expires_at_epoch_s)
|
||||||
|
# Production note: replace with Redis for multi-process deployments.
|
||||||
|
_pending_scout_oauth_states: dict[str, tuple[str, str, str, float]] = {}
|
||||||
|
_SCOUT_OAUTH_TTL_SECONDS = 600 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def _scout_gmail_redirect_uri() -> str:
|
||||||
|
"""Derive the scout Gmail web-callback URI from the configured base OAUTH_REDIRECT_URI.
|
||||||
|
|
||||||
|
``OAUTH_REDIRECT_URI`` is the full path used for login OAuth
|
||||||
|
(e.g. http://localhost:8000/api/v1/auth/oauth/google/web-callback).
|
||||||
|
We strip the path to get the scheme+host base, then append the scout path.
|
||||||
|
"""
|
||||||
|
parsed = urllib.parse.urlparse(settings.OAUTH_REDIRECT_URI)
|
||||||
|
base = f"{parsed.scheme}://{parsed.netloc}"
|
||||||
|
return f"{base}/api/v1/scouts/oauth/gmail/web-callback"
|
||||||
|
|
||||||
|
|
||||||
|
class _ScoutGmailAuthorizeResponse(BaseModel):
|
||||||
|
authorize_url: str
|
||||||
|
|
||||||
|
|
||||||
|
class _ScoutGmailCallbackBody(BaseModel):
|
||||||
|
code: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/oauth/gmail/authorize", response_model=_ScoutGmailAuthorizeResponse)
|
||||||
|
async def scout_gmail_oauth_authorize(
|
||||||
|
scout_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _ScoutGmailAuthorizeResponse:
|
||||||
|
"""Start the Gmail OAuth flow for a specific cloud scout.
|
||||||
|
|
||||||
|
Returns the Google consent-screen URL. The client opens this URL in the
|
||||||
|
system browser; after consent Google redirects to web-callback which bounces
|
||||||
|
to the ``adiuvai://scout/oauth/gmail/callback`` deep link.
|
||||||
|
"""
|
||||||
|
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
"Google OAuth is not configured on this server",
|
||||||
|
)
|
||||||
|
|
||||||
|
code_verifier, code_challenge = generate_pkce_pair()
|
||||||
|
state = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
# Purge expired states to prevent unbounded growth.
|
||||||
|
now = time.time()
|
||||||
|
expired = [s for s, (_, _, _, exp) in _pending_scout_oauth_states.items() if exp < now]
|
||||||
|
for s in expired:
|
||||||
|
del _pending_scout_oauth_states[s]
|
||||||
|
|
||||||
|
_pending_scout_oauth_states[state] = (code_verifier, scout_id, current_user.id, now + _SCOUT_OAUTH_TTL_SECONDS)
|
||||||
|
|
||||||
|
redirect_uri = _scout_gmail_redirect_uri()
|
||||||
|
params = {
|
||||||
|
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": " ".join(_GMAIL_SCOUT_SCOPES),
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"access_type": "offline",
|
||||||
|
"prompt": "consent",
|
||||||
|
}
|
||||||
|
authorize_url = f"{_GOOGLE_AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
return _ScoutGmailAuthorizeResponse(authorize_url=authorize_url)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/oauth/gmail/web-callback", include_in_schema=False)
|
||||||
|
async def scout_gmail_oauth_web_callback(code: str, state: str) -> RedirectResponse:
|
||||||
|
"""Google redirects here after Gmail consent.
|
||||||
|
|
||||||
|
Immediately bounces to the Electron deep link so the desktop app
|
||||||
|
receives the authorization code.
|
||||||
|
"""
|
||||||
|
params = urllib.parse.urlencode({"code": code, "state": state})
|
||||||
|
deep_link = f"adiuvai://scout/oauth/gmail/callback?{params}"
|
||||||
|
return RedirectResponse(url=deep_link, status_code=302)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/oauth/gmail/callback")
|
||||||
|
async def scout_gmail_oauth_callback(
|
||||||
|
body: _ScoutGmailCallbackBody,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange the Gmail authorization code and store the encrypted token on the scout.
|
||||||
|
|
||||||
|
Called by the Electron app after it receives the deep-link callback with
|
||||||
|
the ``code`` and ``state`` params.
|
||||||
|
"""
|
||||||
|
entry = _pending_scout_oauth_states.pop(body.state, None)
|
||||||
|
if entry is None or entry[3] < time.time() or entry[2] != current_user.id:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||||
|
|
||||||
|
code_verifier, scout_id, _, _ = entry
|
||||||
|
|
||||||
|
redirect_uri = _scout_gmail_redirect_uri()
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
_GOOGLE_TOKEN_URL,
|
||||||
|
data={
|
||||||
|
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
|
||||||
|
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||||
|
"code": body.code,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
logger.error("Gmail token exchange failed: %s", exc.response.text)
|
||||||
|
raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code")
|
||||||
|
|
||||||
|
token_data = response.json()
|
||||||
|
|
||||||
|
creds_dict: dict = {
|
||||||
|
"token": token_data["access_token"],
|
||||||
|
"refresh_token": token_data.get("refresh_token"),
|
||||||
|
"token_uri": _GOOGLE_TOKEN_URL,
|
||||||
|
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
|
||||||
|
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||||
|
"scopes": [
|
||||||
|
"https://www.googleapis.com/auth/gmail.readonly",
|
||||||
|
"https://www.googleapis.com/auth/gmail.modify",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
encrypted = encrypt_token(creds_dict)
|
||||||
|
|
||||||
|
scout = await db.get(CloudScoutConfig, scout_id)
|
||||||
|
if scout is None or scout.user_id != current_user.id:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
|
||||||
|
scout.oauth_token_encrypted = encrypted
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# Attempt to set up Gmail push watch so we start receiving Pub/Sub notifications.
|
||||||
|
from app.scouts.connectors.registry import get_connector
|
||||||
|
try:
|
||||||
|
connector = get_connector("gmail")
|
||||||
|
await connector.setup_watch(scout)
|
||||||
|
await db.commit()
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("setup_watch failed for scout %s", scout_id)
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
139
app/billing/quota.py
Normal file
139
app/billing/quota.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Quota checks and atomic token-usage accounting for folder integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.billing.tier_manager import TierManager
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
from app.schemas import BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaExceeded(Exception):
|
||||||
|
"""Raised when a folder operation cannot proceed under the user's tier."""
|
||||||
|
|
||||||
|
def __init__(self, reason: str, message: str) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.reason = reason # "max_files" | "monthly_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenUsageResult:
|
||||||
|
tokens_used: int
|
||||||
|
exhausted: bool
|
||||||
|
|
||||||
|
|
||||||
|
def _current_year_month() -> str:
|
||||||
|
return datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
|
||||||
|
|
||||||
|
_tier_manager = TierManager()
|
||||||
|
|
||||||
|
|
||||||
|
async def check_folder_quota(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
tier: BillingTier,
|
||||||
|
estimated_files: int,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
||||||
|
would be violated. -1 in either feature means unlimited."""
|
||||||
|
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
||||||
|
if max_files != -1 and estimated_files > max_files:
|
||||||
|
raise QuotaExceeded(
|
||||||
|
"max_files",
|
||||||
|
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||||
|
if cap == -1:
|
||||||
|
return
|
||||||
|
ym = _current_year_month()
|
||||||
|
row = (
|
||||||
|
await db.execute(
|
||||||
|
select(MonthlyTokenUsage).where(
|
||||||
|
MonthlyTokenUsage.user_id == user_id,
|
||||||
|
MonthlyTokenUsage.year_month == ym,
|
||||||
|
MonthlyTokenUsage.feature == "folder_index",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
used = row.tokens_used if row else 0
|
||||||
|
if used >= cap:
|
||||||
|
raise QuotaExceeded(
|
||||||
|
"monthly_tokens",
|
||||||
|
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_token_usage(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
feature: str,
|
||||||
|
tokens: int,
|
||||||
|
db: AsyncSession,
|
||||||
|
cap: int | None = None,
|
||||||
|
) -> TokenUsageResult:
|
||||||
|
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
||||||
|
|
||||||
|
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
||||||
|
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
||||||
|
Returns post-update total and whether cap is exhausted.
|
||||||
|
"""
|
||||||
|
ym = _current_year_month()
|
||||||
|
|
||||||
|
# Detect dialect to choose between native upsert and portable fallback.
|
||||||
|
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
||||||
|
|
||||||
|
if dialect_name == "postgresql":
|
||||||
|
# Native atomic upsert — production path.
|
||||||
|
stmt = (
|
||||||
|
pg_insert(MonthlyTokenUsage)
|
||||||
|
.values(
|
||||||
|
user_id=user_id,
|
||||||
|
year_month=ym,
|
||||||
|
feature=feature,
|
||||||
|
tokens_used=tokens,
|
||||||
|
)
|
||||||
|
.on_conflict_do_update(
|
||||||
|
index_elements=["user_id", "year_month", "feature"],
|
||||||
|
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
||||||
|
)
|
||||||
|
.returning(MonthlyTokenUsage.tokens_used)
|
||||||
|
)
|
||||||
|
used: int = (await db.execute(stmt)).scalar_one()
|
||||||
|
await db.commit()
|
||||||
|
else:
|
||||||
|
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
||||||
|
row = (
|
||||||
|
await db.execute(
|
||||||
|
select(MonthlyTokenUsage).where(
|
||||||
|
MonthlyTokenUsage.user_id == user_id,
|
||||||
|
MonthlyTokenUsage.year_month == ym,
|
||||||
|
MonthlyTokenUsage.feature == feature,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
row = MonthlyTokenUsage(
|
||||||
|
user_id=user_id,
|
||||||
|
year_month=ym,
|
||||||
|
feature=feature,
|
||||||
|
tokens_used=tokens,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
else:
|
||||||
|
row.tokens_used += tokens
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(row)
|
||||||
|
used = row.tokens_used
|
||||||
|
|
||||||
|
exhausted = cap is not None and cap != -1 and used >= cap
|
||||||
|
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
||||||
@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": False, # batch queue (Phase 2)
|
"realtime_extraction": False, # batch queue (Phase 2)
|
||||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||||
"proactive_mining": False, # Power+ only (Phase 5)
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
|
"folder_max_files": 200,
|
||||||
|
"folder_monthly_tokens": 100_000,
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||||
"relational_memory": True, # person/project predicates
|
"relational_memory": True, # person/project predicates
|
||||||
"proactive_mining": False, # Power+ only (Phase 5)
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
|
"folder_max_files": 5000,
|
||||||
|
"folder_monthly_tokens": 2_000_000,
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": True,
|
"realtime_extraction": True,
|
||||||
"relational_memory": True, # all predicates incl. custom
|
"relational_memory": True, # all predicates incl. custom
|
||||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
|
"folder_max_files": -1, # unlimited
|
||||||
|
"folder_monthly_tokens": -1, # unlimited
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": True,
|
"realtime_extraction": True,
|
||||||
"relational_memory": True, # all predicates incl. custom
|
"relational_memory": True, # all predicates incl. custom
|
||||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
|
"folder_max_files": -1, # unlimited
|
||||||
|
"folder_monthly_tokens": -1, # unlimited
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,6 +131,13 @@ class TierManager:
|
|||||||
)
|
)
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
||||||
|
"""Return integer feature value for tier. -1 means unlimited."""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if not isinstance(value, int):
|
||||||
|
return 0
|
||||||
|
return value
|
||||||
|
|
||||||
# ── Rate limiting ────────────────────────────────────────────────────
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
|||||||
@@ -23,12 +23,12 @@ class Settings(BaseSettings):
|
|||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||||
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
|
||||||
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||||
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
|
||||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||||
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||||
|
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
||||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||||
@@ -58,6 +58,16 @@ class Settings(BaseSettings):
|
|||||||
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||||
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||||
|
|
||||||
|
# Gmail Pub/Sub topic for push notifications.
|
||||||
|
# Full resource name, e.g. "projects/my-project/topics/gmail-push".
|
||||||
|
# Leave empty in dev — setup_watch will skip registration gracefully.
|
||||||
|
GMAIL_PUBSUB_TOPIC: str = ""
|
||||||
|
# OIDC token audience for Pub/Sub push subscription JWT verification.
|
||||||
|
# Set to the service account email or audience string configured in the
|
||||||
|
# Pub/Sub push subscription. Leave empty in dev to skip verification
|
||||||
|
# (a warning is logged — never silent in production).
|
||||||
|
GMAIL_PUBSUB_AUDIENCE: str = ""
|
||||||
|
|
||||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.core.deep_agent import (
|
|||||||
_relational_memory_injection,
|
_relational_memory_injection,
|
||||||
_run_single_agent_stream,
|
_run_single_agent_stream,
|
||||||
_trace_id_from_context,
|
_trace_id_from_context,
|
||||||
|
build_brief_multi_project_manifest,
|
||||||
)
|
)
|
||||||
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||||
|
|
||||||
@@ -159,6 +160,8 @@ async def run_home_brief(
|
|||||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||||
"""
|
"""
|
||||||
|
from app.agents.folder_agent import FOLDER_TOOLS
|
||||||
|
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
today = date.today().isoformat()
|
today = date.today().isoformat()
|
||||||
language = _resolve_language(context)
|
language = _resolve_language(context)
|
||||||
@@ -171,7 +174,10 @@ async def run_home_brief(
|
|||||||
if today not in system_prompt:
|
if today not in system_prompt:
|
||||||
system_prompt += f"\nToday is {today}."
|
system_prompt += f"\nToday is {today}."
|
||||||
|
|
||||||
tools = _build_read_tools(user_id, trace_id)
|
brief_manifest = await build_brief_multi_project_manifest()
|
||||||
|
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
||||||
|
|
||||||
|
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Single-agent runners for home and floating chat contexts."""
|
"""Single-agent runners for home and contextual chat contexts."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -7,16 +7,18 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any, Literal
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.agents.client_agent import CLIENT_TOOLS
|
||||||
from app.agents.note_agent import NOTE_TOOLS
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
from app.agents.project_agent import PROJECT_TOOLS
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.relations_agent import make_query_relations_tool
|
||||||
from app.agents.task_agent import TASK_TOOLS
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.agent_session_buffer import session_buffer
|
from app.core.scout_session_buffer import session_buffer
|
||||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.core.llm import get_agent_llm, model_for_agent
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
@@ -27,9 +29,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
MAX_HISTORY_TURNS = 20
|
MAX_HISTORY_TURNS = 20
|
||||||
|
|
||||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
|
||||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
|
||||||
|
|
||||||
# Mapping of core-memory language values to natural-language names for prompts.
|
# Mapping of core-memory language values to natural-language names for prompts.
|
||||||
_LANGUAGE_NAMES: dict[str, str] = {
|
_LANGUAGE_NAMES: dict[str, str] = {
|
||||||
"en": "English", "it": "Italian", "es": "Spanish",
|
"en": "English", "it": "Italian", "es": "Spanish",
|
||||||
@@ -58,6 +57,93 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
|||||||
f"All your output text must be written in {lang}."
|
f"All your output text must be written in {lang}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MANIFEST_TOKEN_BUDGET = 3000 # rough budget for <linked_folder> block
|
||||||
|
|
||||||
|
|
||||||
|
def format_folder_manifest(manifest: dict | None) -> str:
|
||||||
|
"""Format a folder manifest into the <linked_folder> block.
|
||||||
|
|
||||||
|
Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
|
||||||
|
Returns empty string if manifest is None or has no files.
|
||||||
|
"""
|
||||||
|
if not manifest or not manifest.get("files"):
|
||||||
|
return ""
|
||||||
|
files = list(manifest["files"])
|
||||||
|
files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
|
||||||
|
|
||||||
|
header = (
|
||||||
|
f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
|
||||||
|
f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
|
||||||
|
)
|
||||||
|
footer_template = "… {} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
|
||||||
|
|
||||||
|
char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token
|
||||||
|
body = ""
|
||||||
|
included = 0
|
||||||
|
for f in files:
|
||||||
|
line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
|
||||||
|
if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
|
||||||
|
break
|
||||||
|
body += line
|
||||||
|
included += 1
|
||||||
|
omitted = len(files) - included
|
||||||
|
if omitted > 0:
|
||||||
|
return header + body + footer_template.format(omitted)
|
||||||
|
return header + body + "</linked_folder>"
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_project_manifest(project_id: str) -> dict | None:
|
||||||
|
"""Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_project_folder_manifest",
|
||||||
|
data={"projectId": project_id},
|
||||||
|
)
|
||||||
|
if not result or not result.get("folderPath"):
|
||||||
|
return None
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def build_brief_multi_project_manifest() -> str:
|
||||||
|
"""Build a compact multi-project manifest for the daily brief agent.
|
||||||
|
|
||||||
|
Calls execute_on_client('list_projects_with_folder_manifests') and keeps
|
||||||
|
the top 5 most-recently-modified files per project.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_projects_with_folder_manifests",
|
||||||
|
data={},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
projects = (result or {}).get("projects") or []
|
||||||
|
if not projects:
|
||||||
|
return ""
|
||||||
|
blocks: list[str] = ["<linked_folders>"]
|
||||||
|
any_entry = False
|
||||||
|
for p in projects:
|
||||||
|
all_files = p.get("files", []) or []
|
||||||
|
files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
|
||||||
|
blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
|
||||||
|
blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
|
||||||
|
if not all_files:
|
||||||
|
blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)")
|
||||||
|
else:
|
||||||
|
for f in files:
|
||||||
|
blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
|
||||||
|
if len(all_files) > 5:
|
||||||
|
blocks.append(f" … {len(all_files) - 5} more files (use read_project_folder_file by relPath)")
|
||||||
|
any_entry = True
|
||||||
|
if not any_entry:
|
||||||
|
return ""
|
||||||
|
blocks.append("</linked_folders>")
|
||||||
|
return "\n".join(blocks)
|
||||||
|
|
||||||
|
|
||||||
def _datetime_context_injection(context: dict[str, Any]) -> str:
|
def _datetime_context_injection(context: dict[str, Any]) -> str:
|
||||||
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
|
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
|
||||||
fp = context.get("format_prefs")
|
fp = context.get("format_prefs")
|
||||||
@@ -265,30 +351,63 @@ For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_t
|
|||||||
{request_context}\
|
{request_context}\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_FLOATING_SYSTEM_PROMPT = """\
|
_CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions.
|
||||||
You are adiuvAI's floating executive assistant.{user_identity}
|
|
||||||
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
|
|
||||||
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
|
|
||||||
|
|
||||||
# How you work
|
Rules:
|
||||||
- Use tools before answering anything factual. Never guess.
|
1. Base context (current view summary) is provided every turn. Treat it as ground truth for ids and names; never invent them.
|
||||||
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant.
|
2. ALL reads go through `get_page_details`. The legacy tools `list_projects`, `get_project`, `list_tasks`, `get_task`, `list_notes`, `get_note` are NOT available in this channel — do not attempt to call them. To find an entity by name, call `get_page_details({entityType: 'projects_all' | 'tasks_all' | 'timeline_all'})` to list, then `get_page_details({entityType: '<type>', entityId})` for the full snapshot.
|
||||||
- Match the user's tone preference. Default to warm-but-direct.
|
3. When the user requests an action that creates or updates an entity:
|
||||||
- When the user asks to remember, forget, or update something, use memory tools.
|
- If the current view is a project and no project is specified, use the current project automatically.
|
||||||
|
- If the current view is the global Tasks / Projects / Timeline list and no project is specified, ASK before attaching to any project. Don't silently create orphan entities.
|
||||||
|
4. The current view can change mid-conversation (user navigates). When you see a system message "User navigated to ...", treat the new view as the active context. Prior turns remain visible but the active scope shifts.
|
||||||
|
5. Notes: you can read note bodies via `get_page_details({entityType:'note'})`. You CANNOT edit, summarize-to-replace, or append. Tell the user "note editing is coming in a later release" if asked.
|
||||||
|
6. Be concise. Default to 1-3 short paragraphs. Bullet lists fine. Don't restate the user's request.
|
||||||
|
7. Never expose ids in prose. Use names. Ids only travel through tool calls.
|
||||||
|
|
||||||
# Filter discipline
|
# Date context
|
||||||
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
|
|
||||||
- The user's own name in the User profile block is for context only — it is NOT a default filter.
|
|
||||||
- When in doubt, omit `assignee` and return the global result.
|
|
||||||
|
|
||||||
# Output format
|
|
||||||
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
|
|
||||||
|
|
||||||
# Date filtering
|
|
||||||
{date_context}
|
{date_context}
|
||||||
|
|
||||||
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
|
# Language
|
||||||
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
|
{language_instruction}
|
||||||
|
"""
|
||||||
|
|
||||||
|
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
|
||||||
|
You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task.
|
||||||
|
Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity}
|
||||||
|
|
||||||
|
# Research workflow
|
||||||
|
Follow these steps in order, using tools:
|
||||||
|
1. Read the task fully (title, description, due date, priority, status, project, comments).
|
||||||
|
2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client.
|
||||||
|
3. If the project has a clientId: call `get_client(id)` to retrieve full client details.
|
||||||
|
4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects.
|
||||||
|
5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions.
|
||||||
|
6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`.
|
||||||
|
7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece.
|
||||||
|
|
||||||
|
# Output structure
|
||||||
|
Write the briefing in the user's language. Use this exact structure:
|
||||||
|
|
||||||
|
**What needs to be done**
|
||||||
|
(1–2 sentences, concrete and specific — what action the user must take)
|
||||||
|
|
||||||
|
**Context you should know**
|
||||||
|
(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies)
|
||||||
|
|
||||||
|
**Suggested first step**
|
||||||
|
(one specific, immediately actionable instruction)
|
||||||
|
|
||||||
|
If this is a writing task, append a canvas block at the very end:
|
||||||
|
<canvas kind="email|document|message">
|
||||||
|
...ready-to-use draft here...
|
||||||
|
</canvas>
|
||||||
|
|
||||||
|
Do NOT include the canvas block for non-writing tasks.
|
||||||
|
Do NOT repeat verbatim task fields the user already sees in the UI.
|
||||||
|
Be concrete — no vague advice. Every bullet should be a fact that changes what the user does.
|
||||||
|
|
||||||
|
# Date context
|
||||||
|
{date_context}
|
||||||
|
|
||||||
# Language
|
# Language
|
||||||
{language_instruction}
|
{language_instruction}
|
||||||
@@ -296,25 +415,35 @@ For specific dates not listed, compute local-midnight in the user timezone and c
|
|||||||
# Known people & projects
|
# Known people & projects
|
||||||
{relational_memory}
|
{relational_memory}
|
||||||
|
|
||||||
# Behavioral hints
|
|
||||||
{proactive_hints}
|
|
||||||
|
|
||||||
# Request context
|
# Request context
|
||||||
{request_context}\
|
{request_context}\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
|
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\
|
||||||
"You are a strict domain classifier for websocket floating requests. "
|
You are an executive assistant continuing a conversation with your principal.
|
||||||
"Return ONLY a JSON object with keys: type, id, section. "
|
You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity}
|
||||||
"Allowed type values: task, timeline, project, node. "
|
|
||||||
"Allowed section values: task, timeline, note, or null. "
|
|
||||||
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
|
||||||
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
|
||||||
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
|
||||||
"If id is unknown, use null. "
|
|
||||||
"No markdown, no prose, JSON only."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
Your briefing:
|
||||||
|
---
|
||||||
|
{briefing_context}
|
||||||
|
---
|
||||||
|
|
||||||
|
Continue from here. Do NOT repeat the briefing. Refer to it when relevant.
|
||||||
|
Help the user execute: edit drafts, refine wording, look up additional details, plan next steps.
|
||||||
|
Stay terse — your principal is a busy executive.
|
||||||
|
|
||||||
|
# Date context
|
||||||
|
{date_context}
|
||||||
|
|
||||||
|
# Language
|
||||||
|
{language_instruction}
|
||||||
|
|
||||||
|
# Known people & projects
|
||||||
|
{relational_memory}
|
||||||
|
|
||||||
|
# Request context
|
||||||
|
{request_context}\
|
||||||
|
"""
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
if content is None:
|
if content is None:
|
||||||
@@ -393,6 +522,55 @@ def _all_tools() -> list[Any]:
|
|||||||
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Contextual sidebar tools ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_page_details(
|
||||||
|
entity_type: str = "",
|
||||||
|
entity_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Fetch full details for the entity currently in view.
|
||||||
|
|
||||||
|
entity_type: one of 'project' | 'task' | 'note' | 'timeline_event' |
|
||||||
|
'tasks_all' | 'projects_all' | 'timeline_all'.
|
||||||
|
entity_id: UUID of the entity for singular entity views. Omit for list views.
|
||||||
|
|
||||||
|
The Electron drizzle-executor fulfils this op against local SQLite and
|
||||||
|
returns the row(s) as a JSON tool result.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_page_details",
|
||||||
|
table=entity_type or "unknown",
|
||||||
|
data={"entityId": entity_id or None},
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
return "No details found."
|
||||||
|
return str(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _contextual_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
"""Return the tool palette for the contextual sidebar agent.
|
||||||
|
|
||||||
|
Read ops go through get_page_details only — legacy list_*/get_* tools
|
||||||
|
return shallow snapshots and cause the agent to under-answer (see
|
||||||
|
smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0). Writes are limited
|
||||||
|
to entity creation + task update; note edits are next-sprint.
|
||||||
|
"""
|
||||||
|
from app.agents.note_agent import create_note # noqa: PLC0415
|
||||||
|
from app.agents.task_agent import create_task, update_task # noqa: PLC0415
|
||||||
|
from app.agents.timeline_agent import create_timeline # noqa: PLC0415
|
||||||
|
|
||||||
|
return [
|
||||||
|
get_page_details,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
create_note,
|
||||||
|
create_timeline,
|
||||||
|
*_memory_tools(user_id, trace_id),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
debug = context.get("_debug")
|
debug = context.get("_debug")
|
||||||
if isinstance(debug, dict):
|
if isinstance(debug, dict):
|
||||||
@@ -495,70 +673,6 @@ def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
|||||||
return "\n".join(output_lines)
|
return "\n".join(output_lines)
|
||||||
|
|
||||||
|
|
||||||
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
|
||||||
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
|
||||||
_FLOATING_EMPTY_FALLBACK = "No results found."
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_floating_markup_fragment(text: str) -> str:
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
cleaned = _GENERIC_TAG_RE.sub("", text)
|
|
||||||
return _BRACKETED_ID_RE.sub("", cleaned)
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_floating_markup(text: str) -> str:
|
|
||||||
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
|
|
||||||
cleaned = _strip_floating_markup_fragment(text)
|
|
||||||
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
|
|
||||||
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
|
||||||
return "\n".join(line for line in lines if line)
|
|
||||||
|
|
||||||
|
|
||||||
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
|
||||||
fallback = _strip_floating_markup_fragment(raw_text or "")
|
|
||||||
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
|
||||||
return fallback or _FLOATING_EMPTY_FALLBACK
|
|
||||||
|
|
||||||
|
|
||||||
class _FloatingStreamSanitizer:
|
|
||||||
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._pending = ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
|
||||||
boundary = len(text)
|
|
||||||
|
|
||||||
last_lt = text.rfind("<")
|
|
||||||
if last_lt != -1 and ">" not in text[last_lt:]:
|
|
||||||
boundary = min(boundary, last_lt)
|
|
||||||
|
|
||||||
last_lb = text.rfind("[")
|
|
||||||
if last_lb != -1 and "]" not in text[last_lb:]:
|
|
||||||
boundary = min(boundary, last_lb)
|
|
||||||
|
|
||||||
if boundary == len(text):
|
|
||||||
return text, ""
|
|
||||||
return text[:boundary], text[boundary:]
|
|
||||||
|
|
||||||
def feed(self, chunk: str) -> str:
|
|
||||||
combined = f"{self._pending}{chunk}"
|
|
||||||
safe_text, self._pending = self._split_safe_boundary(combined)
|
|
||||||
return _strip_floating_markup_fragment(safe_text)
|
|
||||||
|
|
||||||
def finalize(self) -> str:
|
|
||||||
# Drop dangling unfinished wrappers at the very end.
|
|
||||||
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
|
||||||
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
|
||||||
self._pending = ""
|
|
||||||
return _strip_floating_markup_fragment(tail)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_memory_label(path_or_label: str) -> str:
|
def _normalize_memory_label(path_or_label: str) -> str:
|
||||||
value = path_or_label.strip()
|
value = path_or_label.strip()
|
||||||
if value.startswith("/memories/"):
|
if value.startswith("/memories/"):
|
||||||
@@ -679,6 +793,25 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
|||||||
lines = [f"- {item}" for item in results]
|
lines = [f"- {item}" for item in results]
|
||||||
return "Recall memory results:\n" + "\n".join(lines)
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def search_associative(query: str, limit: int = 5) -> str:
|
||||||
|
"""Semantic search across associative (archival) memory for a given query.
|
||||||
|
|
||||||
|
Use this to surface long-term memories related to a topic, client, or task
|
||||||
|
that may not appear in recent episodes.
|
||||||
|
|
||||||
|
query: natural-language search phrase.
|
||||||
|
limit: max results (default 5).
|
||||||
|
"""
|
||||||
|
logger.info("deep_agent: search_associative trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=limit)
|
||||||
|
if not results:
|
||||||
|
return "No associative memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Associative memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
memory_list_blocks,
|
memory_list_blocks,
|
||||||
memory_get,
|
memory_get,
|
||||||
@@ -689,182 +822,37 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
|||||||
archival_memory_insert,
|
archival_memory_insert,
|
||||||
archival_memory_search,
|
archival_memory_search,
|
||||||
conversation_search,
|
conversation_search,
|
||||||
|
search_associative,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
"""Return memory tools that only read — safe for the read-only brief-agent subset."""
|
"""Return memory tools that only read — safe for the read-only brief-agent subset."""
|
||||||
all_mem = _memory_tools(user_id, trace_id)
|
all_mem = _memory_tools(user_id, trace_id)
|
||||||
_read_names = {"memory_list_blocks", "memory_get", "archival_memory_search", "conversation_search"}
|
_read_names = {
|
||||||
|
"memory_list_blocks", "memory_get", "archival_memory_search",
|
||||||
|
"conversation_search", "search_associative",
|
||||||
|
}
|
||||||
return [t for t in all_mem if t.name in _read_names]
|
return [t for t in all_mem if t.name in _read_names]
|
||||||
|
|
||||||
|
|
||||||
|
def _brief_research_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
"""Return the full tool palette for Stage-1 task brief research (read-only)."""
|
||||||
|
return [
|
||||||
|
*TASK_TOOLS,
|
||||||
|
*PROJECT_TOOLS,
|
||||||
|
*NOTE_TOOLS,
|
||||||
|
*TIMELINE_TOOLS,
|
||||||
|
*CLIENT_TOOLS,
|
||||||
|
*_read_only_memory_tools(user_id, trace_id),
|
||||||
|
make_query_relations_tool(user_id, trace_id),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||||
|
|
||||||
|
|
||||||
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
|
||||||
lowered = message.lower()
|
|
||||||
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
|
||||||
return "timeline"
|
|
||||||
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
|
||||||
return "task"
|
|
||||||
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
|
||||||
return "note"
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
|
||||||
type_raw = str(payload.get("type") or "").strip().lower()
|
|
||||||
domain_type: FloatingDomainType = "task"
|
|
||||||
if type_raw in {"task", "timeline", "project", "node"}:
|
|
||||||
domain_type = type_raw
|
|
||||||
|
|
||||||
id_value = payload.get("id")
|
|
||||||
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
|
||||||
if domain_type == "project" and not domain_id:
|
|
||||||
domain_id = fallback_id
|
|
||||||
|
|
||||||
section_raw = payload.get("section")
|
|
||||||
section: FloatingDomainSection | None = None
|
|
||||||
if isinstance(section_raw, str):
|
|
||||||
section_candidate = section_raw.strip().lower()
|
|
||||||
if section_candidate in {"task", "timeline", "note"}:
|
|
||||||
section = section_candidate
|
|
||||||
|
|
||||||
if domain_type != "project":
|
|
||||||
section = None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": domain_type,
|
|
||||||
"id": domain_id,
|
|
||||||
"section": section,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
|
||||||
raw = text.strip()
|
|
||||||
if not raw:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = json.loads(raw)
|
|
||||||
return parsed if isinstance(parsed, dict) else None
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
||||||
if not match:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = json.loads(match.group(0))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
|
||||||
return parsed if isinstance(parsed, dict) else None
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
|
||||||
section = _detect_domain_section(message)
|
|
||||||
scope = context.get("scope") if isinstance(context, dict) else None
|
|
||||||
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
|
||||||
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
|
||||||
|
|
||||||
if isinstance(scope, dict):
|
|
||||||
scope_type = str(scope.get("type") or "").strip().lower()
|
|
||||||
scope_id = scope.get("id")
|
|
||||||
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
|
||||||
|
|
||||||
if scope_type in {"task", "tasks"}:
|
|
||||||
return {"type": "task", "id": scope_id_value, "section": None}
|
|
||||||
if scope_type in {"project", "projects"}:
|
|
||||||
project_scope_id = scope_id_value or project_id
|
|
||||||
return {
|
|
||||||
"type": "project",
|
|
||||||
"id": project_scope_id,
|
|
||||||
"section": section,
|
|
||||||
}
|
|
||||||
if scope_type in {"note", "notes"}:
|
|
||||||
return {
|
|
||||||
"type": "node",
|
|
||||||
"id": scope_id_value,
|
|
||||||
"section": None,
|
|
||||||
}
|
|
||||||
if scope_type in {"timeline", "timelines"}:
|
|
||||||
return {"type": "timeline", "id": scope_id_value, "section": None}
|
|
||||||
|
|
||||||
lowered = message.lower()
|
|
||||||
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
|
||||||
return {
|
|
||||||
"type": "project",
|
|
||||||
"id": project_id,
|
|
||||||
"section": section,
|
|
||||||
}
|
|
||||||
if section == "timeline":
|
|
||||||
return {"type": "timeline", "id": None, "section": None}
|
|
||||||
if section == "note":
|
|
||||||
return {"type": "node", "id": None, "section": None}
|
|
||||||
return {"type": "task", "id": None, "section": None}
|
|
||||||
|
|
||||||
|
|
||||||
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
|
||||||
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
|
||||||
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
|
||||||
|
|
||||||
classifier_context = {
|
|
||||||
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
|
||||||
"resolved_project_id": project_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
llm = get_agent_llm("classifier")
|
|
||||||
classifier_messages = [
|
|
||||||
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=(
|
|
||||||
f"Message:\n{message}\n\n"
|
|
||||||
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
|
||||||
lf = get_langfuse()
|
|
||||||
_, classifier_prompt_obj = get_prompt_or_fallback(
|
|
||||||
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract user/session from context for Langfuse attribution
|
|
||||||
_debug = context.get("_debug") if isinstance(context, dict) else None
|
|
||||||
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
|
|
||||||
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
|
|
||||||
|
|
||||||
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="floating-classifier",
|
|
||||||
model=model_for_agent("classifier"),
|
|
||||||
prompt=classifier_prompt_obj,
|
|
||||||
input=classifier_messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(classifier_messages)
|
|
||||||
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
|
||||||
else:
|
|
||||||
response = await llm.ainvoke(classifier_messages)
|
|
||||||
parsed = _parse_json_object(_as_text(response.content))
|
|
||||||
if parsed is not None:
|
|
||||||
domain = _normalize_domain_payload(parsed, project_id)
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
|
||||||
domain.get("type"),
|
|
||||||
domain.get("id"),
|
|
||||||
domain.get("section"),
|
|
||||||
)
|
|
||||||
return domain
|
|
||||||
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
|
||||||
|
|
||||||
return _infer_floating_domain_rule_based(message, context)
|
|
||||||
|
|
||||||
|
|
||||||
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
|
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
|
||||||
if not history:
|
if not history:
|
||||||
return []
|
return []
|
||||||
@@ -1193,32 +1181,30 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
return _normalize_tagged_list_lines(response, message)
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
|
|
||||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
|
||||||
domain = await _infer_floating_domain(message, prepared_context)
|
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
|
|
||||||
response = await _run_single_agent(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message=message,
|
|
||||||
context=prepared_context,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
agent_name="floating-agent",
|
|
||||||
conversation_history=context.get("conversation_history"),
|
|
||||||
)
|
|
||||||
sanitized = _strip_floating_markup(response)
|
|
||||||
if not sanitized and response:
|
|
||||||
sanitized = _fallback_from_raw_floating_text(response)
|
|
||||||
return sanitized, domain
|
|
||||||
|
|
||||||
|
|
||||||
async def run_home_stream(
|
async def run_home_stream(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
|
project_id: str | None = None,
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
from app.agents.folder_agent import FOLDER_TOOLS
|
||||||
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
|
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
|
||||||
|
|
||||||
|
manifest_block = ""
|
||||||
|
if project_id:
|
||||||
|
manifest = await _fetch_project_manifest(project_id)
|
||||||
|
manifest_block = format_folder_manifest(manifest)
|
||||||
|
if not manifest_block:
|
||||||
|
# No specific project context — surface all linked folders so the agent
|
||||||
|
# can answer questions like "tell me about project X" using its files.
|
||||||
|
manifest_block = await build_brief_multi_project_manifest()
|
||||||
|
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
|
||||||
|
|
||||||
|
trace_id = _trace_id_from_context(prepared_context)
|
||||||
|
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
|
||||||
|
|
||||||
text_chunks: list[str] = []
|
text_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -1227,6 +1213,7 @@ async def run_home_stream(
|
|||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
agent_name="home-agent",
|
agent_name="home-agent",
|
||||||
|
tools=tools,
|
||||||
conversation_history=context.get("conversation_history"),
|
conversation_history=context.get("conversation_history"),
|
||||||
):
|
):
|
||||||
event_type, data = event
|
event_type, data = event
|
||||||
@@ -1240,47 +1227,99 @@ async def run_home_stream(
|
|||||||
yield "token", normalized
|
yield "token", normalized
|
||||||
|
|
||||||
|
|
||||||
async def run_floating_stream(
|
async def run_contextual_stream(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
|
scope: "ContextualScope", # type: ignore[name-defined]
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
prepared_context = await _prepare_context(message, context)
|
"""Run the contextual agent for a single user turn.
|
||||||
domain = await _infer_floating_domain(message, prepared_context)
|
|
||||||
yield "floating_domain", domain
|
Injects the rendered scope block into the system prompt and exposes
|
||||||
|
the contextual tool set.
|
||||||
|
Note-edit tools (propose_note_edit) are intentionally excluded.
|
||||||
|
|
||||||
|
*context contract*: callers MUST include ``context["_debug"]["session_id"]``
|
||||||
|
(a non-empty str) so that ``_session_id_from_context`` can extract it for
|
||||||
|
tracing and episode storage downstream. The WS handler in device_ws.py
|
||||||
|
satisfies this by always populating ``_debug`` before calling this function.
|
||||||
|
"""
|
||||||
|
from app.schemas.contextual import ContextualScope, render_scope_block # noqa: PLC0415
|
||||||
|
|
||||||
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
trace_id = _trace_id_from_context(prepared_context)
|
||||||
|
|
||||||
|
system_prompt, langfuse_prompt = _build_system_prompt(
|
||||||
|
"contextual_system", _CONTEXTUAL_SYSTEM_PROMPT, prepared_context,
|
||||||
|
)
|
||||||
|
scope_block = render_scope_block(scope)
|
||||||
|
system_prompt = system_prompt + f"\n\n## Current view\n{scope_block}"
|
||||||
|
|
||||||
|
tools = _contextual_tools(user_id, trace_id)
|
||||||
|
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
|
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
|
||||||
emitted_sanitized = False
|
|
||||||
raw_chunks: list[str] = []
|
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
agent_name="floating-agent",
|
agent_name="contextual-agent",
|
||||||
|
tools=tools,
|
||||||
conversation_history=context.get("conversation_history"),
|
conversation_history=context.get("conversation_history"),
|
||||||
):
|
):
|
||||||
event_type, data = event
|
|
||||||
if event_type != "token":
|
|
||||||
yield event
|
yield event
|
||||||
continue
|
|
||||||
|
|
||||||
raw_chunk = str(data or "")
|
|
||||||
raw_chunks.append(raw_chunk)
|
|
||||||
sanitized_chunk = sanitizer.feed(raw_chunk)
|
|
||||||
if sanitized_chunk:
|
|
||||||
emitted_sanitized = True
|
|
||||||
yield "token", sanitized_chunk
|
|
||||||
|
|
||||||
tail = sanitizer.finalize()
|
async def run_task_brief_research_stream(
|
||||||
if tail:
|
user_id: str,
|
||||||
emitted_sanitized = True
|
task_id: str,
|
||||||
yield "token", tail
|
context: dict[str, Any],
|
||||||
|
project_id: str | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Stage-1 executive assistant: deep research for one task.
|
||||||
|
|
||||||
if not emitted_sanitized and raw_chunks:
|
Yields ``("token", chunk)`` events like other stream runners.
|
||||||
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
|
||||||
|
which the WS handler strips and emits as a ``canvas_draft`` mutation.
|
||||||
|
"""
|
||||||
|
from app.agents.folder_agent import FOLDER_TOOLS
|
||||||
|
|
||||||
|
prepared_context = await _prepare_context(f"task:{task_id}", context)
|
||||||
|
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
|
||||||
|
|
||||||
|
# Inject task_id so the agent knows what to look up first.
|
||||||
|
research_message = (
|
||||||
|
f"Prepare a briefing dossier for task ID: {task_id}\n"
|
||||||
|
"Follow the research workflow: read the task, then project, then client, "
|
||||||
|
"then cross-project relations, then relevant memory. "
|
||||||
|
"End with a concrete suggested first step. "
|
||||||
|
"If this is a writing task, include a <canvas kind=\"...\"> draft."
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt, langfuse_prompt = _build_system_prompt(
|
||||||
|
"task_brief_research_system",
|
||||||
|
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT,
|
||||||
|
prepared_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
manifest_block = ""
|
||||||
|
if project_id:
|
||||||
|
manifest = await _fetch_project_manifest(project_id)
|
||||||
|
manifest_block = format_folder_manifest(manifest)
|
||||||
|
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
|
||||||
|
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=research_message,
|
||||||
|
context=prepared_context,
|
||||||
|
max_steps=12,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="task-brief-agent",
|
||||||
|
tools=tools,
|
||||||
|
conversation_history=None,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
|
|||||||
183
app/core/folder_indexer.py
Normal file
183
app/core/folder_indexer.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Per-file summarisation for project folder integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from pypdf import PdfReader
|
||||||
|
from docx import Document as DocxDocument
|
||||||
|
|
||||||
|
from app.core.langfuse_client import (
|
||||||
|
compile_prompt,
|
||||||
|
extract_usage,
|
||||||
|
get_langfuse,
|
||||||
|
get_prompt_or_fallback,
|
||||||
|
)
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
|
||||||
|
_TEXT_FALLBACK = (
|
||||||
|
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
||||||
|
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
||||||
|
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
||||||
|
)
|
||||||
|
_IMAGE_FALLBACK = (
|
||||||
|
"You are summarising an image attached to a project folder.\n"
|
||||||
|
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
||||||
|
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
||||||
|
)
|
||||||
|
_MAX_INPUT_CHARS = 6000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IndexResult:
|
||||||
|
summary: str
|
||||||
|
tokens_used: int
|
||||||
|
|
||||||
|
|
||||||
|
async def _llm_text(messages: list) -> object:
|
||||||
|
"""Make the LLM call for text summarisation.
|
||||||
|
|
||||||
|
Defined as a standalone async function so tests can patch it cleanly
|
||||||
|
without needing to mock the LLM object itself.
|
||||||
|
"""
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||||
|
return await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def _llm_vision(messages: list) -> object:
|
||||||
|
"""Make the LLM call for vision (image) summarisation.
|
||||||
|
|
||||||
|
Accepts the message list and returns the response directly, mirroring
|
||||||
|
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
||||||
|
"""
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||||
|
return await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
||||||
|
"""Return a compact summary of an image file using vision.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image_b64:
|
||||||
|
Base64-encoded image bytes.
|
||||||
|
mime:
|
||||||
|
MIME type of the image, e.g. ``"image/png"``.
|
||||||
|
file_name:
|
||||||
|
Optional file name, attached to the Langfuse trace as input metadata.
|
||||||
|
"""
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=template),
|
||||||
|
HumanMessage(content=[
|
||||||
|
{"type": "text", "text": "Summarise this image."},
|
||||||
|
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is not None:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="folder-summarize-image",
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input={"file_name": file_name, "mime": mime},
|
||||||
|
) as gen:
|
||||||
|
response = await _llm_vision(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
gen.update(output=response.content, usage_details=usage)
|
||||||
|
else:
|
||||||
|
response = await _llm_vision(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
summary = (response.content or "").strip()[:500]
|
||||||
|
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
||||||
|
"""Return a compact summary of a text file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
content:
|
||||||
|
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
||||||
|
ext:
|
||||||
|
File extension including the leading dot, e.g. ``".md"``.
|
||||||
|
name:
|
||||||
|
File name, e.g. ``"kickoff.md"``.
|
||||||
|
"""
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
||||||
|
truncated = content[:_MAX_INPUT_CHARS]
|
||||||
|
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=compiled),
|
||||||
|
HumanMessage(content="Summarise this file."),
|
||||||
|
]
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is not None:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="folder-summarize-text",
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
||||||
|
) as gen:
|
||||||
|
response = await _llm_text(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
gen.update(output=response.content, usage_details=usage)
|
||||||
|
else:
|
||||||
|
response = await _llm_text(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
summary = (response.content or "").strip()[:500]
|
||||||
|
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pdf_text(pdf_b64: str) -> str:
|
||||||
|
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
||||||
|
reader = PdfReader(buf)
|
||||||
|
parts: list[str] = []
|
||||||
|
for page in reader.pages:
|
||||||
|
try:
|
||||||
|
parts.append(page.extract_text() or "")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_docx_text(docx_b64: str) -> str:
|
||||||
|
buf = io.BytesIO(base64.b64decode(docx_b64))
|
||||||
|
doc = DocxDocument(buf)
|
||||||
|
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
||||||
|
"""Return a compact summary of a PDF file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pdf_b64:
|
||||||
|
Base64-encoded PDF bytes.
|
||||||
|
name:
|
||||||
|
File name, e.g. ``"report.pdf"``.
|
||||||
|
"""
|
||||||
|
text = _extract_pdf_text(pdf_b64)
|
||||||
|
if not text:
|
||||||
|
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||||
|
return await summarize_text(content=text, ext=".pdf", name=name)
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
||||||
|
"""Return a compact summary of a DOCX file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
docx_b64:
|
||||||
|
Base64-encoded DOCX bytes.
|
||||||
|
name:
|
||||||
|
File name, e.g. ``"spec.docx"``.
|
||||||
|
"""
|
||||||
|
text = _extract_docx_text(docx_b64)
|
||||||
|
if not text:
|
||||||
|
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||||
|
return await summarize_text(content=text, ext=".docx", name=name)
|
||||||
@@ -103,14 +103,15 @@ def get_llm(
|
|||||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||||
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
|
||||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||||
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||||
|
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
||||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||||
|
"note-summarizer": lambda: "gpt-4o-mini",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
51
app/core/note_summarizer.py
Normal file
51
app/core/note_summarizer.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Note summarizer — generates a compact AI summary for a note.
|
||||||
|
|
||||||
|
Called fire-and-forget from create_note / update_note tools so the
|
||||||
|
``notes.ai_summary`` column stays current without blocking the agent loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_prompt_or_fallback
|
||||||
|
from app.core.llm import get_agent_llm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_FALLBACK_PROMPT = """\
|
||||||
|
Summarize this note in <=250 characters. Be terse and dense.
|
||||||
|
Keep proper nouns, dates, decisions, and action items.
|
||||||
|
Do not start with "This note".
|
||||||
|
Respond with the summary text only — no intro, no labels.
|
||||||
|
|
||||||
|
Title: {title}
|
||||||
|
Content: {content}"""
|
||||||
|
|
||||||
|
_MAX_CONTENT_CHARS = 4000
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_note_summary(title: str, content: str) -> str:
|
||||||
|
"""Return a <=250-char summary of *title* + *content*.
|
||||||
|
|
||||||
|
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
|
||||||
|
fallback. Truncates *content* to 4000 chars before sending to avoid
|
||||||
|
token waste on large notes.
|
||||||
|
"""
|
||||||
|
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
|
||||||
|
trimmed = content[:_MAX_CONTENT_CHARS]
|
||||||
|
system_prompt = template.format(title=title, content=trimmed)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = get_agent_llm("note-summarizer")
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content="Generate the summary."),
|
||||||
|
])
|
||||||
|
text = response.content if isinstance(response.content, str) else ""
|
||||||
|
return text.strip()[:250]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("note_summarizer: failed to generate summary: %s", exc)
|
||||||
|
return ""
|
||||||
@@ -2,12 +2,36 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
||||||
|
_CANVAS_BLOCK_RE = re.compile(
|
||||||
|
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
||||||
|
re.DOTALL | re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
||||||
|
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
||||||
|
|
||||||
|
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
||||||
|
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
||||||
|
"""
|
||||||
|
match = _CANVAS_BLOCK_RE.search(text)
|
||||||
|
if not match:
|
||||||
|
return text, None, None
|
||||||
|
|
||||||
|
canvas_kind = match.group(1).strip()
|
||||||
|
canvas_content = match.group(2).strip()
|
||||||
|
visible = text[: match.start()] + text[match.end() :]
|
||||||
|
visible = visible.strip()
|
||||||
|
return visible, canvas_content, canvas_kind
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
|
||||||
|
|
||||||
|
|
||||||
class StreamFormatter:
|
class StreamFormatter:
|
||||||
@@ -23,14 +47,6 @@ class StreamFormatter:
|
|||||||
started = False
|
started = False
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "floating_domain":
|
|
||||||
if isinstance(data, dict):
|
|
||||||
yield WsFloatingDomain(
|
|
||||||
request_id=self.request_id,
|
|
||||||
domain=data,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if event_type != "token":
|
if event_type != "token":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ from app.core.llm import get_agent_llm, model_for_agent
|
|||||||
from app.core.preprocessors import detect_content_type, preprocess
|
from app.core.preprocessors import detect_content_type, preprocess
|
||||||
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import ScoutRunLog, CloudScoutConfig, LocalScoutConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -169,7 +169,7 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
|||||||
next_run: datetime = cron.get_next(datetime)
|
next_run: datetime = cron.get_next(datetime)
|
||||||
return now >= next_run
|
return now >= next_run
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
|
logger.warning("scout_runner: cannot parse cron %r: %s", schedule_cron, exc)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@@ -290,7 +290,7 @@ async def _run_agent_with_tools(
|
|||||||
call_name = str(call.get("name", ""))
|
call_name = str(call.get("name", ""))
|
||||||
call_args = call.get("args", {})
|
call_args = call.get("args", {})
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: tool_call name=%s args=%s",
|
"scout_runner: tool_call name=%s args=%s",
|
||||||
call_name,
|
call_name,
|
||||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
)
|
)
|
||||||
@@ -305,7 +305,7 @@ async def _run_agent_with_tools(
|
|||||||
tool_output = await tool_fn.ainvoke(call_args)
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: tool_result name=%s output=%s",
|
"scout_runner: tool_result name=%s output=%s",
|
||||||
call_name,
|
call_name,
|
||||||
str(tool_output)[:200],
|
str(tool_output)[:200],
|
||||||
)
|
)
|
||||||
@@ -360,7 +360,7 @@ async def _scan_directories(
|
|||||||
try:
|
try:
|
||||||
result = await execute_on_client(action="list_directory", data={"path": path})
|
result = await execute_on_client(action="list_directory", data={"path": path})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
logger.warning("scout_runner: list_directory failed %r: %s", path, exc)
|
||||||
return
|
return
|
||||||
for entry in result.get("entries", []):
|
for entry in result.get("entries", []):
|
||||||
entry_path = entry.get("path", "")
|
entry_path = entry.get("path", "")
|
||||||
@@ -414,7 +414,7 @@ async def _fetch_projects() -> list[dict]:
|
|||||||
result = await execute_on_client(action="select", table="projects")
|
result = await execute_on_client(action="select", table="projects")
|
||||||
return result.get("rows", [])
|
return result.get("rows", [])
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
logger.warning("scout_runner: failed to fetch projects: %s", exc)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -442,7 +442,7 @@ async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
|||||||
)
|
)
|
||||||
return result.get("rows", [])
|
return result.get("rows", [])
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
logger.warning("scout_runner: failed to fetch %s: %s", domain, exc)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -555,8 +555,8 @@ def _get_no_match_behavior(agent_config: dict) -> str:
|
|||||||
|
|
||||||
async def run_local_agent(
|
async def run_local_agent(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
config: LocalAgentConfig,
|
config: LocalScoutConfig,
|
||||||
run_log: AgentRunLog,
|
run_log: ScoutRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
run_context: dict | None = None,
|
run_context: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -586,7 +586,7 @@ async def run_local_agent(
|
|||||||
|
|
||||||
if not is_online:
|
if not is_online:
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: skip run=%s — device %r offline for user=%s",
|
"scout_runner: skip run=%s — device %r offline for user=%s",
|
||||||
run_id,
|
run_id,
|
||||||
target_device_id or "<any>",
|
target_device_id or "<any>",
|
||||||
user_id,
|
user_id,
|
||||||
@@ -605,7 +605,7 @@ async def run_local_agent(
|
|||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
items_processed = 0
|
items_processed = 0
|
||||||
items_created = 0
|
items_created = 0
|
||||||
agent_config: dict = config.agent_config or {}
|
agent_config: dict = config.scout_config or {}
|
||||||
processing_tools = _build_processing_tools(config.data_types)
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -616,7 +616,7 @@ async def run_local_agent(
|
|||||||
last_run_at=config.last_run_at,
|
last_run_at=config.last_run_at,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths)
|
"scout_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not file_paths:
|
if not file_paths:
|
||||||
@@ -641,7 +641,7 @@ async def run_local_agent(
|
|||||||
raw_content: str = file_result.get("content", "")
|
raw_content: str = file_result.get("content", "")
|
||||||
if not raw_content.strip():
|
if not raw_content.strip():
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"agent_runner: run=%s skipping empty file %r", run_id, file_path
|
"scout_runner: run=%s skipping empty file %r", run_id, file_path
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -651,16 +651,21 @@ async def run_local_agent(
|
|||||||
preprocessed = preprocess(content_type, raw_content)
|
preprocessed = preprocess(content_type, raw_content)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s file=%r content_type=%s clean_len=%d",
|
"scout_runner: run=%s file=%r content_type=%s clean_len=%d",
|
||||||
run_id, file_path, content_type, len(preprocessed.clean_text),
|
run_id, file_path, content_type, len(preprocessed.clean_text),
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Phase B: single LLM call ─────────────────────────
|
# ── Phase B: single LLM call ─────────────────────────
|
||||||
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
||||||
no_match_behavior = _get_no_match_behavior(agent_config)
|
no_match_behavior = _get_no_match_behavior(agent_config)
|
||||||
global_rules_lines = "\n".join(
|
base_global_rules = list(agent_config.get("global_rules", []))
|
||||||
f"- {r}" for r in agent_config.get("global_rules", [])
|
if "notes" in config.data_types:
|
||||||
|
base_global_rules.append(
|
||||||
|
"For notes: when updating an existing note use `propose_note_edit` "
|
||||||
|
"(type=append/insert/replace) so the user can review AI changes. "
|
||||||
|
"Only call `update_note` for complete content replacement without review."
|
||||||
)
|
)
|
||||||
|
global_rules_lines = "\n".join(f"- {r}" for r in base_global_rules)
|
||||||
metadata_section = _format_metadata(preprocessed.metadata)
|
metadata_section = _format_metadata(preprocessed.metadata)
|
||||||
|
|
||||||
system_prompt = compile_prompt(
|
system_prompt = compile_prompt(
|
||||||
@@ -706,19 +711,19 @@ async def run_local_agent(
|
|||||||
projects_block = _format_projects(projects)
|
projects_block = _format_projects(projects)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s file=%r created=%d result=%s",
|
"scout_runner: run=%s file=%r created=%d result=%s",
|
||||||
run_id, file_path, file_created, result_text[:200],
|
run_id, file_path, file_created, result_text[:200],
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"Error processing '{file_path}': {exc}")
|
errors.append(f"Error processing '{file_path}': {exc}")
|
||||||
logger.error(
|
logger.error(
|
||||||
"agent_runner: run=%s file=%r failed: %s", run_id, file_path, exc
|
"scout_runner: run=%s file=%r failed: %s", run_id, file_path, exc
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"Agent run failed: {exc}")
|
errors.append(f"Agent run failed: {exc}")
|
||||||
logger.error("agent_runner: run=%s failed: %s", run_id, exc)
|
logger.error("scout_runner: run=%s failed: %s", run_id, exc)
|
||||||
finally:
|
finally:
|
||||||
_running_agents.discard(agent_id)
|
_running_agents.discard(agent_id)
|
||||||
clear_client_executor()
|
clear_client_executor()
|
||||||
@@ -739,7 +744,7 @@ async def run_local_agent(
|
|||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
"scout_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
run_id,
|
run_id,
|
||||||
final_status,
|
final_status,
|
||||||
items_processed,
|
items_processed,
|
||||||
@@ -757,7 +762,7 @@ async def run_local_agent(
|
|||||||
})
|
})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"agent_runner: run=%s failed to send run_complete: %s", run_id, exc
|
"scout_runner: run=%s failed to send run_complete: %s", run_id, exc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -768,8 +773,8 @@ _CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
|||||||
|
|
||||||
async def run_cloud_agent(
|
async def run_cloud_agent(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
config: CloudAgentConfig,
|
config: CloudScoutConfig,
|
||||||
run_log: AgentRunLog,
|
run_log: ScoutRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a cloud connector agent run end-to-end.
|
"""Execute a cloud connector agent run end-to-end.
|
||||||
@@ -792,7 +797,7 @@ async def run_cloud_agent(
|
|||||||
# ── 1. Device online check ─────────────────────────────────────────
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
if not device_mgr.is_online(user_id):
|
if not device_mgr.is_online(user_id):
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: skip cloud run=%s — no device online for user=%s",
|
"scout_runner: skip cloud run=%s — no device online for user=%s",
|
||||||
run_id,
|
run_id,
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
@@ -817,7 +822,7 @@ async def run_cloud_agent(
|
|||||||
try:
|
try:
|
||||||
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
|
logger.error("scout_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
|
||||||
await _finalize_run(
|
await _finalize_run(
|
||||||
run_log,
|
run_log,
|
||||||
status="error",
|
status="error",
|
||||||
@@ -863,7 +868,7 @@ async def run_cloud_agent(
|
|||||||
raw_messages = []
|
raw_messages = []
|
||||||
except RuntimeError as exc:
|
except RuntimeError as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"agent_runner: provider fetch failed for cloud agent %s: %s", config.id, exc
|
"scout_runner: provider fetch failed for cloud agent %s: %s", config.id, exc
|
||||||
)
|
)
|
||||||
await _finalize_run(
|
await _finalize_run(
|
||||||
run_log,
|
run_log,
|
||||||
@@ -876,7 +881,7 @@ async def run_cloud_agent(
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
|
"scout_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
|
||||||
config.id,
|
config.id,
|
||||||
len(raw_messages),
|
len(raw_messages),
|
||||||
config.provider,
|
config.provider,
|
||||||
@@ -936,16 +941,16 @@ async def run_cloud_agent(
|
|||||||
new_encrypted = encrypt_token(refreshed)
|
new_encrypted = encrypt_token(refreshed)
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
cfg_result = await db.execute(
|
cfg_result = await db.execute(
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
select(CloudScoutConfig).where(CloudScoutConfig.id == config.id)
|
||||||
)
|
)
|
||||||
cfg_row = cfg_result.scalar_one_or_none()
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
if cfg_row:
|
if cfg_row:
|
||||||
cfg_row.oauth_token_encrypted = new_encrypted
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
await db.commit()
|
await db.commit()
|
||||||
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
|
logger.debug("scout_runner: refreshed OAuth token persisted for agent %s", config.id)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"agent_runner: failed to persist refreshed token for agent %s: %s",
|
"scout_runner: failed to persist refreshed token for agent %s: %s",
|
||||||
config.id,
|
config.id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
@@ -969,7 +974,7 @@ async def run_cloud_agent(
|
|||||||
config_type="cloud",
|
config_type="cloud",
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
|
"scout_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
run_id,
|
run_id,
|
||||||
final_status,
|
final_status,
|
||||||
items_processed,
|
items_processed,
|
||||||
@@ -991,7 +996,7 @@ async def trigger_pending_runs(
|
|||||||
Called as a background task from the device WS endpoint on ``device_hello``.
|
Called as a background task from the device WS endpoint on ``device_hello``.
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
|
"scout_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
)
|
)
|
||||||
@@ -1002,7 +1007,7 @@ async def trigger_pending_runs(
|
|||||||
|
|
||||||
|
|
||||||
async def _finalize_run(
|
async def _finalize_run(
|
||||||
run_log: AgentRunLog,
|
run_log: ScoutRunLog,
|
||||||
*,
|
*,
|
||||||
status: str,
|
status: str,
|
||||||
items_processed: int = 0,
|
items_processed: int = 0,
|
||||||
@@ -1026,14 +1031,14 @@ async def _finalize_run(
|
|||||||
if update_config_last_run and config_id:
|
if update_config_last_run and config_id:
|
||||||
if config_type == "local":
|
if config_type == "local":
|
||||||
cfg_result = await db.execute(
|
cfg_result = await db.execute(
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
select(LocalScoutConfig).where(LocalScoutConfig.id == config_id)
|
||||||
)
|
)
|
||||||
cfg = cfg_result.scalar_one_or_none()
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
if cfg:
|
if cfg:
|
||||||
cfg.last_run_at = now
|
cfg.last_run_at = now
|
||||||
elif config_type == "cloud":
|
elif config_type == "cloud":
|
||||||
cfg_result = await db.execute(
|
cfg_result = await db.execute(
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
select(CloudScoutConfig).where(CloudScoutConfig.id == config_id)
|
||||||
)
|
)
|
||||||
cfg = cfg_result.scalar_one_or_none()
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
if cfg:
|
if cfg:
|
||||||
@@ -1042,5 +1047,5 @@ async def _finalize_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
|
"scout_runner: failed to finalize run_log=%s: %s", run_log.id, exc
|
||||||
)
|
)
|
||||||
@@ -54,6 +54,43 @@ class _SessionBuffer:
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._store.pop((user_id, session_id), None)
|
self._store.pop((user_id, session_id), None)
|
||||||
|
|
||||||
|
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
|
||||||
|
"""Append a synthetic system message to the buffer for the given session.
|
||||||
|
|
||||||
|
Creates the session slot if it does not yet exist. Used by the
|
||||||
|
contextual_scope_update handler to inject navigation events without
|
||||||
|
making an LLM call.
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import SystemMessage # noqa: PLC0415
|
||||||
|
|
||||||
|
key = (user_id, session_id)
|
||||||
|
with self._lock:
|
||||||
|
entry = self._store.get(key)
|
||||||
|
if entry is None:
|
||||||
|
msgs: list[BaseMessage] = [SystemMessage(content=text)]
|
||||||
|
else:
|
||||||
|
_, existing = entry
|
||||||
|
msgs = list(existing) + [SystemMessage(content=text)]
|
||||||
|
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
|
||||||
|
self._store[key] = (time.monotonic(), capped)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextualBufferProxy:
|
||||||
|
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
|
||||||
|
|
||||||
|
Returned by get_session_buffer() so callers can call
|
||||||
|
``proxy.append_system_message(text)`` without threading user_id/session_id
|
||||||
|
through every call site.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
|
||||||
|
self._buf = buf
|
||||||
|
self._user_id = user_id
|
||||||
|
self._session_id = session_id
|
||||||
|
|
||||||
|
def append_system_message(self, text: str) -> None:
|
||||||
|
self._buf.append_system_message(self._user_id, self._session_id, text)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
||||||
session_buffer = _SessionBuffer()
|
session_buffer = _SessionBuffer()
|
||||||
@@ -8,7 +8,7 @@ blocking the event loop.
|
|||||||
Token refresh is handled transparently: when the stored access token has
|
Token refresh is handled transparently: when the stored access token has
|
||||||
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||||
token to obtain a fresh one. The caller is responsible for persisting
|
token to obtain a fresh one. The caller is responsible for persisting
|
||||||
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
any refreshed credentials back to ``CloudScoutConfig.oauth_token_encrypted``
|
||||||
(see ``agent_runner.run_cloud_agent``).
|
(see ``agent_runner.run_cloud_agent``).
|
||||||
|
|
||||||
Credential dict shape (Google OAuth2):
|
Credential dict shape (Google OAuth2):
|
||||||
|
|||||||
103
app/main.py
103
app/main.py
@@ -77,8 +77,98 @@ async def _memory_cron_tick() -> None:
|
|||||||
_log.warning("memory cron tick: failed: %s", exc)
|
_log.warning("memory cron tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _scout_cron_tick() -> None:
|
||||||
|
"""Every-15-min cron: poll enabled cloud scouts (cron-fallback; push is primary).
|
||||||
|
|
||||||
|
Skips any scout whose ``last_run_at`` is within the last 5 minutes so
|
||||||
|
a push notification and the fallback cron don't double-fire within the
|
||||||
|
same window.
|
||||||
|
"""
|
||||||
|
import logging # noqa: PLC0415
|
||||||
|
import uuid # noqa: PLC0415
|
||||||
|
from datetime import datetime, timezone # noqa: PLC0415
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_log.info("scout cron tick: starting")
|
||||||
|
try:
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
from app.models import CloudScoutConfig # noqa: PLC0415
|
||||||
|
from app.scouts.engine import ScoutEngine # noqa: PLC0415
|
||||||
|
from sqlalchemy import select # noqa: PLC0415
|
||||||
|
|
||||||
|
async with async_session() as session:
|
||||||
|
scouts = (await session.execute(
|
||||||
|
select(CloudScoutConfig).where(CloudScoutConfig.enabled == True) # noqa: E712
|
||||||
|
)).scalars().all()
|
||||||
|
|
||||||
|
engine = ScoutEngine()
|
||||||
|
triggered = 0
|
||||||
|
for scout in scouts:
|
||||||
|
# Rate-limit guard: push is primary; skip if ran within 5 minutes.
|
||||||
|
if scout.last_run_at:
|
||||||
|
elapsed = (datetime.now(tz=timezone.utc) - scout.last_run_at).total_seconds()
|
||||||
|
if elapsed < 300:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await engine.trigger_scout(uuid.UUID(str(scout.id)))
|
||||||
|
triggered += 1
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("scout cron tick: trigger failed scout=%s: %s", scout.id, exc)
|
||||||
|
|
||||||
|
_log.info("scout cron tick: done triggered=%d total=%d", triggered, len(scouts))
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("scout cron tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
async def _scout_watch_renewal_tick() -> None:
|
||||||
|
"""Every-24-hour cron: re-issue Gmail users.watch for scouts expiring within 24h.
|
||||||
|
|
||||||
|
Handles missing or misconfigured connectors gracefully — logs and continues.
|
||||||
|
"""
|
||||||
|
import logging # noqa: PLC0415
|
||||||
|
from datetime import datetime, timedelta, timezone # noqa: PLC0415
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
_log.info("scout watch renewal tick: starting")
|
||||||
|
try:
|
||||||
|
from app.db import async_session # noqa: PLC0415
|
||||||
|
from app.models import CloudScoutConfig # noqa: PLC0415
|
||||||
|
from app.scouts.connectors.registry import get_connector # noqa: PLC0415
|
||||||
|
from sqlalchemy import select # noqa: PLC0415
|
||||||
|
|
||||||
|
threshold = datetime.now(tz=timezone.utc) + timedelta(hours=24)
|
||||||
|
renewed = 0
|
||||||
|
async with async_session() as session:
|
||||||
|
scouts = (await session.execute(
|
||||||
|
select(CloudScoutConfig).where(
|
||||||
|
CloudScoutConfig.enabled == True, # noqa: E712
|
||||||
|
CloudScoutConfig.provider == "gmail",
|
||||||
|
CloudScoutConfig.gmail_watch_expires_at <= threshold,
|
||||||
|
)
|
||||||
|
)).scalars().all()
|
||||||
|
|
||||||
|
for scout in scouts:
|
||||||
|
try:
|
||||||
|
connector = get_connector("gmail")
|
||||||
|
await connector.renew_watch(scout)
|
||||||
|
renewed += 1
|
||||||
|
except Exception:
|
||||||
|
_log.exception("scout watch renewal tick: renew failed scout=%s", scout.id)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
_log.info("scout watch renewal tick: done renewed=%d", renewed)
|
||||||
|
except Exception as exc:
|
||||||
|
_log.warning("scout watch renewal tick: failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
# Startup: register source connectors.
|
||||||
|
from app.scouts.connectors.gmail import GmailConnector # noqa: PLC0415
|
||||||
|
from app.scouts.connectors.registry import register_connector # noqa: PLC0415
|
||||||
|
register_connector(GmailConnector())
|
||||||
|
|
||||||
# Startup: ensure agent tool modules are loaded.
|
# Startup: ensure agent tool modules are loaded.
|
||||||
import app.agents # noqa: F401
|
import app.agents # noqa: F401
|
||||||
|
|
||||||
@@ -89,6 +179,14 @@ async def lifespan(app: FastAPI):
|
|||||||
scheduler = AsyncIOScheduler()
|
scheduler = AsyncIOScheduler()
|
||||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||||
|
scheduler.add_job(
|
||||||
|
_scout_cron_tick, "interval", minutes=15,
|
||||||
|
id="scout_cron_tick", replace_existing=True,
|
||||||
|
)
|
||||||
|
scheduler.add_job(
|
||||||
|
_scout_watch_renewal_tick, "interval", hours=24,
|
||||||
|
id="scout_watch_renewal_tick", replace_existing=True,
|
||||||
|
)
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||||
|
|
||||||
@@ -124,12 +222,13 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
from app.api.routes import scouts, auth, billing, chat, device_ws, memory, scout_webhooks
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(scouts.router, prefix="/api/v1")
|
||||||
|
app.include_router(scout_webhooks.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
app.include_router(memory.router, prefix="/api/v1")
|
app.include_router(memory.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
|||||||
104
app/models.py
104
app/models.py
@@ -1,15 +1,15 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
Only auth, billing, agent config, and memory data live here.
|
Only auth, billing, scout config, and memory data live here.
|
||||||
User content (notes, tasks, etc.) lives exclusively on the client.
|
User content (notes, tasks, etc.) lives exclusively on the client.
|
||||||
|
|
||||||
Table inventory:
|
Table inventory:
|
||||||
users — account credentials + tier
|
users — account credentials + tier
|
||||||
refresh_tokens — hashed refresh token store
|
refresh_tokens — hashed refresh token store
|
||||||
subscriptions — Stripe subscription records
|
subscriptions — Stripe subscription records
|
||||||
local_agent_configs — per-device batch agent configs
|
local_scout_configs — per-device batch scout configs
|
||||||
cloud_agent_configs — OAuth-backed cloud agent configs
|
cloud_scout_configs — OAuth-backed cloud scout configs
|
||||||
agent_run_logs — execution history for all agents
|
scout_run_logs — execution history for all scouts
|
||||||
memory_core — per-user persistent key/value preferences (encrypted)
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
@@ -34,8 +34,10 @@ from sqlalchemy import (
|
|||||||
LargeBinary,
|
LargeBinary,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
Uuid,
|
Uuid,
|
||||||
func,
|
func,
|
||||||
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -158,8 +160,8 @@ class Subscription(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
class LocalScoutConfig(Base):
|
||||||
__tablename__ = "local_agent_configs"
|
__tablename__ = "local_scout_configs"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
@@ -172,7 +174,7 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
scout_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
@@ -184,17 +186,17 @@ class LocalAgentConfig(Base):
|
|||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||||
back_populates="local_agent",
|
back_populates="local_scout",
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
foreign_keys="ScoutRunLog.scout_id",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
overlaps="run_logs,cloud_agent",
|
overlaps="run_logs,cloud_scout",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfig(Base):
|
class CloudScoutConfig(Base):
|
||||||
__tablename__ = "cloud_agent_configs"
|
__tablename__ = "cloud_scout_configs"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
@@ -217,52 +219,88 @@ class CloudAgentConfig(Base):
|
|||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
|
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
|
||||||
|
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
|
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
|
||||||
|
|
||||||
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||||
back_populates="cloud_agent",
|
back_populates="cloud_scout",
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
foreign_keys="ScoutRunLog.scout_id",
|
||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
overlaps="run_logs,local_agent",
|
overlaps="run_logs,local_scout",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentRunLog(Base):
|
class ScoutTriageQueue(Base):
|
||||||
__tablename__ = "agent_run_logs"
|
__tablename__ = "scout_triage_queue"
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
|
||||||
|
)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||||
|
scout_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False)
|
||||||
|
source_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||||
|
triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
status: Mapped[str] = mapped_column(String(20), nullable=False, default="queued", server_default="queued")
|
||||||
|
triaged_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||||
|
delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutRunLog(Base):
|
||||||
|
__tablename__ = "scout_run_logs"
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
)
|
)
|
||||||
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
|
# Plain string — not a FK because it references either local_scout_configs or cloud_scout_configs
|
||||||
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
|
# depending on scout_type. Query by (scout_id, scout_type) to locate the source config.
|
||||||
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
scout_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
scout_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
||||||
user_id: Mapped[str] = mapped_column(
|
user_id: Mapped[str] = mapped_column(
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
)
|
)
|
||||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
started_at: Mapped[datetime] = mapped_column(
|
started_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
local_agent: Mapped[LocalAgentConfig | None] = relationship(
|
local_scout: Mapped["LocalScoutConfig | None"] = relationship(
|
||||||
back_populates="run_logs",
|
back_populates="run_logs",
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
foreign_keys="ScoutRunLog.scout_id",
|
||||||
overlaps="run_logs,cloud_agent",
|
overlaps="run_logs,cloud_scout",
|
||||||
)
|
)
|
||||||
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
|
cloud_scout: Mapped["CloudScoutConfig | None"] = relationship(
|
||||||
back_populates="run_logs",
|
back_populates="run_logs",
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
foreign_keys="ScoutRunLog.scout_id",
|
||||||
overlaps="run_logs,local_agent",
|
overlaps="run_logs,local_scout",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MonthlyTokenUsage(Base):
|
||||||
|
__tablename__ = "monthly_token_usage"
|
||||||
|
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
)
|
||||||
|
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
|
||||||
|
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||||
|
|
||||||
|
|
||||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -73,11 +73,9 @@ class WsFrameType(str, Enum):
|
|||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
floating_request = "floating_request"
|
|
||||||
stream_start = "stream_start"
|
stream_start = "stream_start"
|
||||||
stream_text = "stream_text"
|
stream_text = "stream_text"
|
||||||
stream_end = "stream_end"
|
stream_end = "stream_end"
|
||||||
floating_domain = "floating_domain"
|
|
||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
@@ -87,6 +85,22 @@ class WsFrameType(str, Enum):
|
|||||||
journey_reply = "journey_reply"
|
journey_reply = "journey_reply"
|
||||||
# ── v5 brief frame types ──────────────────────────────────────────
|
# ── v5 brief frame types ──────────────────────────────────────────
|
||||||
brief_request = "brief_request"
|
brief_request = "brief_request"
|
||||||
|
# ── v6 task brief frame types ─────────────────────────────────────
|
||||||
|
task_brief_request = "task_brief_request"
|
||||||
|
# ── v7 folder index frame types ───────────────────────────────────
|
||||||
|
index_session_start = "index_session_start"
|
||||||
|
index_file_batch = "index_file_batch"
|
||||||
|
index_session_cancel = "index_session_cancel"
|
||||||
|
index_file_result = "index_file_result"
|
||||||
|
index_session_progress = "index_session_progress"
|
||||||
|
index_session_done = "index_session_done"
|
||||||
|
# ── v8 contextual sidebar frame types ────────────────────────────
|
||||||
|
contextual_request = "contextual_request"
|
||||||
|
contextual_scope_update = "contextual_scope_update"
|
||||||
|
contextual_scope_ack = "contextual_scope_ack"
|
||||||
|
# ── v9 scout proposal frame types ────────────────────────────────
|
||||||
|
SCOUT_PROPOSAL = "scout_proposal"
|
||||||
|
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -136,7 +150,7 @@ class WsDeviceHello(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
|
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
|
||||||
device_id: str
|
device_id: str
|
||||||
agent_ids: list[str] = Field(default_factory=list)
|
scout_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -152,13 +166,6 @@ class FormatPrefsModel(BaseModel):
|
|||||||
now_iso: str = ""
|
now_iso: str = ""
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingScope(BaseModel):
|
|
||||||
"""Scope for a floating request — narrows the agent to a specific entity."""
|
|
||||||
|
|
||||||
type: Literal["task", "project", "note", "timeline"]
|
|
||||||
id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsHomeRequest(BaseModel):
|
class WsHomeRequest(BaseModel):
|
||||||
"""Client → Server: Home chat message."""
|
"""Client → Server: Home chat message."""
|
||||||
|
|
||||||
@@ -168,15 +175,6 @@ class WsHomeRequest(BaseModel):
|
|||||||
format_prefs: FormatPrefsModel | None = None
|
format_prefs: FormatPrefsModel | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingRequest(BaseModel):
|
|
||||||
"""Client → Server: Floating chat message scoped to an entity."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
|
||||||
message: str
|
|
||||||
scope: WsFloatingScope
|
|
||||||
format_prefs: FormatPrefsModel | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsBriefRequest(BaseModel):
|
class WsBriefRequest(BaseModel):
|
||||||
"""Client → Server: Request a plain-text brief (home or project)."""
|
"""Client → Server: Request a plain-text brief (home or project)."""
|
||||||
|
|
||||||
@@ -209,28 +207,13 @@ class WsStreamEnd(BaseModel):
|
|||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
mutations: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsDomain(BaseModel):
|
# ── Scout Config V2 ───────────────────────────────────────────────────
|
||||||
"""Structured floating domain payload for UI routing decisions."""
|
|
||||||
|
|
||||||
type: Literal["task", "timeline", "project", "node"]
|
|
||||||
id: str | None = None
|
|
||||||
section: Literal["task", "timeline", "note"] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class ScoutContentTypeConfig(BaseModel):
|
||||||
"""Server → Client: domain determined for a floating request."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
|
||||||
request_id: str
|
|
||||||
domain: WsDomain
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Config V2 ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeConfig(BaseModel):
|
|
||||||
"""Per-type extraction config produced by the journey chatbot."""
|
"""Per-type extraction config produced by the journey chatbot."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
@@ -240,34 +223,34 @@ class ContentTypeConfig(BaseModel):
|
|||||||
extraction_prompt: str
|
extraction_prompt: str
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
class ScoutConfig(BaseModel):
|
||||||
"""Structured agent configuration (replaces freeform prompt_template)."""
|
"""Structured scout configuration (replaces freeform prompt_template)."""
|
||||||
|
|
||||||
content_types: list[ContentTypeConfig] = []
|
content_types: list[ScoutContentTypeConfig] = []
|
||||||
global_rules: list[str] = []
|
global_rules: list[str] = []
|
||||||
data_types: list[str] = []
|
data_types: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Scout Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class AgentCatalogItem(BaseModel):
|
class ScoutCatalogItem(BaseModel):
|
||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckRequest(BaseModel):
|
class ScoutCreationCheckRequest(BaseModel):
|
||||||
active_agents: int = Field(ge=0, default=0)
|
active_agents: int = Field(ge=0, default=0)
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckResponse(BaseModel):
|
class ScoutCreationCheckResponse(BaseModel):
|
||||||
allowed: bool
|
allowed: bool
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
active_agents: int
|
active_agents: int
|
||||||
limit: int
|
limit: int
|
||||||
|
|
||||||
|
|
||||||
class AgentTriggerRequest(BaseModel):
|
class ScoutTriggerRequest(BaseModel):
|
||||||
directory: str = Field(min_length=1)
|
directory: str = Field(min_length=1)
|
||||||
device_id: str = Field(default="")
|
device_id: str = Field(default="")
|
||||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
@@ -279,9 +262,9 @@ class AgentTriggerRequest(BaseModel):
|
|||||||
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Scout Run Log ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class AgentRunLogResponse(BaseModel):
|
class ScoutRunLogResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
agent_id: str
|
agent_id: str
|
||||||
agent_type: Literal["local", "cloud"]
|
agent_type: Literal["local", "cloud"]
|
||||||
@@ -295,3 +278,25 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scout Proposal Frame Models ───────────────────────────────────────
|
||||||
|
|
||||||
|
class ScoutProposalPayload(BaseModel):
|
||||||
|
id: str
|
||||||
|
scout_id: str
|
||||||
|
source_type: str
|
||||||
|
source_msg_ref: str
|
||||||
|
raw_subject: str | None = None
|
||||||
|
raw_snippet: str | None = None
|
||||||
|
category: Literal["unprocessed"] = "unprocessed"
|
||||||
|
payload: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutProposalFrame(BaseModel):
|
||||||
|
type: Literal[WsFrameType.SCOUT_PROPOSAL]
|
||||||
|
proposal: ScoutProposalPayload
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutProposalAckFrame(BaseModel):
|
||||||
|
type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
|
||||||
|
proposal_id: str
|
||||||
73
app/schemas/contextual.py
Normal file
73
app/schemas/contextual.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""Contextual sidebar scope schema and prompt block renderer.
|
||||||
|
|
||||||
|
ContextualScope mirrors the TypeScript ContextualScope type sent by the
|
||||||
|
Electron renderer when the user opens the side chat anchored to a specific
|
||||||
|
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
|
||||||
|
them to snake_case Python attributes automatically.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
from pydantic.alias_generators import to_camel
|
||||||
|
|
||||||
|
|
||||||
|
PageType = Literal[
|
||||||
|
"timeline",
|
||||||
|
"tasks",
|
||||||
|
"projects-list",
|
||||||
|
"project",
|
||||||
|
"note",
|
||||||
|
]
|
||||||
|
|
||||||
|
EntityType = Literal["project", "note", "task", "timeline_event"]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextualScope(BaseModel):
|
||||||
|
"""Scope payload sent by the Electron renderer for contextual chat.
|
||||||
|
|
||||||
|
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
|
||||||
|
alias generator maps them to snake_case Python attrs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
|
||||||
|
|
||||||
|
page: PageType
|
||||||
|
entity_type: Optional[EntityType] = None
|
||||||
|
entity_id: Optional[str] = None
|
||||||
|
entity_name: Optional[str] = None
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
char_count: Optional[int] = None
|
||||||
|
counts: Optional[dict[str, int]] = None
|
||||||
|
filters: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
def render_scope_block(scope: ContextualScope) -> str:
|
||||||
|
"""Produce a single-paragraph human-readable summary of the current view
|
||||||
|
for injection into the contextual agent system prompt.
|
||||||
|
|
||||||
|
Never emits internal ids — only names. The LLM is told to use names in
|
||||||
|
prose; ids travel through tool calls.
|
||||||
|
"""
|
||||||
|
if scope.entity_type == "project":
|
||||||
|
c = scope.counts or {}
|
||||||
|
return (
|
||||||
|
f"User is viewing the project {scope.entity_name!r}. "
|
||||||
|
f"{c.get('tasks', 0)} tasks, "
|
||||||
|
f"{c.get('notes', 0)} notes, "
|
||||||
|
f"{c.get('milestones', 0)} milestones."
|
||||||
|
)
|
||||||
|
if scope.entity_type == "note":
|
||||||
|
return (
|
||||||
|
f"User is viewing the note {scope.entity_name!r} "
|
||||||
|
f"({scope.char_count or 0} characters)."
|
||||||
|
)
|
||||||
|
if scope.page == "tasks":
|
||||||
|
return "User is viewing the global Tasks list (all projects)."
|
||||||
|
if scope.page == "timeline":
|
||||||
|
return "User is viewing the global Timeline view."
|
||||||
|
if scope.page == "projects-list":
|
||||||
|
return "User is viewing the Projects list."
|
||||||
|
return f"User is on page {scope.page}."
|
||||||
0
app/scouts/__init__.py
Normal file
0
app/scouts/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
56
app/scouts/connectors/base.py
Normal file
56
app/scouts/connectors/base.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Source connector Protocol and shared item types.
|
||||||
|
|
||||||
|
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
|
||||||
|
shared ScoutEngine interface. Each connector owns:
|
||||||
|
|
||||||
|
* how to enumerate new items since the last poll (``list_new``)
|
||||||
|
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
|
||||||
|
* how to fetch a single item's full content for in-memory triage
|
||||||
|
(``fetch_content``) — this content MUST NOT be persisted by the engine
|
||||||
|
* how to archive/trash an item (``archive``) for spam handling
|
||||||
|
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ItemRef(BaseModel):
|
||||||
|
source_msg_ref: str
|
||||||
|
received_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ItemMetadata(BaseModel):
|
||||||
|
subject: str | None = None
|
||||||
|
sender: str | None = None
|
||||||
|
snippet: str | None = None
|
||||||
|
received_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ItemContent(BaseModel):
|
||||||
|
metadata: ItemMetadata
|
||||||
|
body_text: str
|
||||||
|
raw_headers: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TriageVerdict(BaseModel):
|
||||||
|
verdict: Literal["relevant", "spam"]
|
||||||
|
reason: str
|
||||||
|
confidence: float = Field(ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceConnector(Protocol):
|
||||||
|
"""Adapter for a third-party data source (Gmail, Slack, ...)."""
|
||||||
|
|
||||||
|
source_type: str # e.g. "gmail"
|
||||||
|
|
||||||
|
async def list_new(self, scout) -> list[ItemRef]: ...
|
||||||
|
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata: ...
|
||||||
|
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent: ...
|
||||||
|
async def archive(self, scout, ref: ItemRef) -> None: ...
|
||||||
|
async def setup_watch(self, scout) -> None: ...
|
||||||
|
async def renew_watch(self, scout) -> None: ...
|
||||||
213
app/scouts/connectors/gmail.py
Normal file
213
app/scouts/connectors/gmail.py
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
"""Gmail SourceConnector — wraps the existing GmailClient.
|
||||||
|
|
||||||
|
Responsibilities:
|
||||||
|
* list_new: incremental fetch since the scout's stored gmail_history_id
|
||||||
|
* fetch_metadata: subject + sender + snippet only (Gmail metadata format)
|
||||||
|
* fetch_content: full body text — transient, never persisted by engine
|
||||||
|
* archive: move a message to Gmail Trash (recoverable for 30 days)
|
||||||
|
* setup_watch / renew_watch: Gmail push notifications via Pub/Sub
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.integrations import decrypt_token
|
||||||
|
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_plain_text_body(payload: dict) -> str:
|
||||||
|
"""Recursively walk a Gmail message payload to find text/plain content."""
|
||||||
|
import base64
|
||||||
|
mime_type = payload.get("mimeType", "")
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = payload.get("body", {}).get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
if mime_type.startswith("multipart/"):
|
||||||
|
for part in payload.get("parts", []):
|
||||||
|
text = _extract_plain_text_body(part)
|
||||||
|
if text:
|
||||||
|
return text
|
||||||
|
# text/html fallback: strip tags rudimentarily if no text/plain part
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = payload.get("body", {}).get("data", "")
|
||||||
|
if data:
|
||||||
|
import re
|
||||||
|
html = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return re.sub(r"<[^>]+>", " ", html)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_gmail_service(scout):
|
||||||
|
"""Return a synchronous Google API client for low-level metadata/history calls."""
|
||||||
|
from googleapiclient.discovery import build
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
creds_info = decrypt_token(scout.oauth_token_encrypted)
|
||||||
|
credentials = Credentials(
|
||||||
|
token=creds_info.get("token"),
|
||||||
|
refresh_token=creds_info.get("refresh_token"),
|
||||||
|
token_uri=creds_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=creds_info.get("client_id"),
|
||||||
|
client_secret=creds_info.get("client_secret"),
|
||||||
|
scopes=creds_info.get("scopes"),
|
||||||
|
)
|
||||||
|
return build("gmail", "v1", credentials=credentials, cache_discovery=False)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailConnector:
|
||||||
|
source_type = "gmail"
|
||||||
|
|
||||||
|
# ── list_new ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def list_new(self, scout) -> list[ItemRef]:
|
||||||
|
"""Return new message refs since scout.gmail_history_id.
|
||||||
|
|
||||||
|
On first run (gmail_history_id is None/empty), records the current
|
||||||
|
historyId without backfilling — avoids flooding the user with old mail.
|
||||||
|
Updates scout.gmail_history_id in-place (caller must persist to DB).
|
||||||
|
"""
|
||||||
|
def _sync() -> tuple[list[ItemRef], str | None]:
|
||||||
|
service = _get_gmail_service(scout)
|
||||||
|
history_id = scout.gmail_history_id
|
||||||
|
refs: list[ItemRef] = []
|
||||||
|
new_history_id = history_id
|
||||||
|
|
||||||
|
if history_id:
|
||||||
|
resp = (
|
||||||
|
service.users()
|
||||||
|
.history()
|
||||||
|
.list(
|
||||||
|
userId="me",
|
||||||
|
startHistoryId=history_id,
|
||||||
|
historyTypes=["messageAdded"],
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
for entry in resp.get("history", []):
|
||||||
|
for added in entry.get("messagesAdded", []):
|
||||||
|
refs.append(ItemRef(source_msg_ref=added["message"]["id"]))
|
||||||
|
new_history_id = resp.get("historyId", history_id)
|
||||||
|
else:
|
||||||
|
# First run: capture baseline history id without backfilling.
|
||||||
|
profile = service.users().getProfile(userId="me").execute()
|
||||||
|
new_history_id = profile["historyId"]
|
||||||
|
|
||||||
|
return refs, new_history_id
|
||||||
|
|
||||||
|
refs, new_history_id = await asyncio.to_thread(_sync)
|
||||||
|
if new_history_id and new_history_id != scout.gmail_history_id:
|
||||||
|
scout.gmail_history_id = new_history_id
|
||||||
|
return refs
|
||||||
|
|
||||||
|
# ── fetch_metadata ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
|
||||||
|
"""Fetch subject, sender, snippet only — uses Gmail metadata format (no body)."""
|
||||||
|
|
||||||
|
def _sync() -> ItemMetadata:
|
||||||
|
service = _get_gmail_service(scout)
|
||||||
|
msg = (
|
||||||
|
service.users()
|
||||||
|
.messages()
|
||||||
|
.get(
|
||||||
|
userId="me",
|
||||||
|
id=ref.source_msg_ref,
|
||||||
|
format="metadata",
|
||||||
|
metadataHeaders=["Subject", "From", "Date"],
|
||||||
|
)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
headers = {
|
||||||
|
h["name"]: h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
return ItemMetadata(
|
||||||
|
subject=headers.get("Subject"),
|
||||||
|
sender=headers.get("From"),
|
||||||
|
snippet=msg.get("snippet"),
|
||||||
|
received_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_sync)
|
||||||
|
|
||||||
|
# ── fetch_content ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
|
||||||
|
"""Fetch full body text for a single message — transient, must not be persisted."""
|
||||||
|
|
||||||
|
def _sync() -> ItemContent:
|
||||||
|
service = _get_gmail_service(scout)
|
||||||
|
msg = service.users().messages().get(
|
||||||
|
userId="me", id=ref.source_msg_ref, format="full",
|
||||||
|
).execute()
|
||||||
|
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
|
||||||
|
body_text = _extract_plain_text_body(msg.get("payload", {}))
|
||||||
|
return ItemContent(
|
||||||
|
metadata=ItemMetadata(
|
||||||
|
subject=headers.get("Subject"),
|
||||||
|
sender=headers.get("From"),
|
||||||
|
snippet=msg.get("snippet"),
|
||||||
|
received_at=None,
|
||||||
|
),
|
||||||
|
body_text=body_text,
|
||||||
|
raw_headers=headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
return await asyncio.to_thread(_sync)
|
||||||
|
|
||||||
|
# ── archive ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def archive(self, scout, ref: ItemRef) -> None:
|
||||||
|
"""Move the message to Gmail Trash (recoverable for 30 days)."""
|
||||||
|
|
||||||
|
def _sync() -> None:
|
||||||
|
service = _get_gmail_service(scout)
|
||||||
|
service.users().messages().trash(
|
||||||
|
userId="me", id=ref.source_msg_ref
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
await asyncio.to_thread(_sync)
|
||||||
|
|
||||||
|
# ── watch management ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def setup_watch(self, scout) -> None:
|
||||||
|
"""Register a Gmail Pub/Sub push watch for the INBOX label.
|
||||||
|
|
||||||
|
Requires ``settings.GMAIL_PUBSUB_TOPIC`` to be set to the full topic
|
||||||
|
resource name (e.g. ``projects/my-project/topics/gmail-push``).
|
||||||
|
Logs a warning and returns without error if the topic is not configured.
|
||||||
|
"""
|
||||||
|
topic = settings.GMAIL_PUBSUB_TOPIC
|
||||||
|
if not topic:
|
||||||
|
logger.warning(
|
||||||
|
"setup_watch: GMAIL_PUBSUB_TOPIC is not configured — skipping watch setup"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
def _sync() -> None:
|
||||||
|
service = _get_gmail_service(scout)
|
||||||
|
request_body = {
|
||||||
|
"labelIds": ["INBOX"],
|
||||||
|
"topicName": topic,
|
||||||
|
}
|
||||||
|
resp = service.users().watch(userId="me", body=request_body).execute()
|
||||||
|
scout.gmail_history_id = resp.get("historyId")
|
||||||
|
expiration_ms = resp.get("expiration")
|
||||||
|
if expiration_ms:
|
||||||
|
scout.gmail_watch_expires_at = datetime.fromtimestamp(
|
||||||
|
int(expiration_ms) / 1000, tz=timezone.utc
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.to_thread(_sync)
|
||||||
|
|
||||||
|
async def renew_watch(self, scout) -> None:
|
||||||
|
"""Renew an existing Gmail Pub/Sub watch (same as setup_watch)."""
|
||||||
|
await self.setup_watch(scout)
|
||||||
32
app/scouts/connectors/registry.py
Normal file
32
app/scouts/connectors/registry.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Connector registry — single source of truth for source_type -> connector."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_CONNECTORS: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_connector(connector: Any) -> None:
|
||||||
|
"""Register a SourceConnector instance under its ``source_type``.
|
||||||
|
|
||||||
|
Calling twice with the same ``source_type`` replaces the prior entry —
|
||||||
|
useful for tests and hot-reload, but in production each connector
|
||||||
|
should be registered exactly once at startup.
|
||||||
|
"""
|
||||||
|
if not getattr(connector, "source_type", None):
|
||||||
|
raise ValueError("Connector must declare a non-empty source_type")
|
||||||
|
_CONNECTORS[connector.source_type] = connector
|
||||||
|
|
||||||
|
|
||||||
|
def get_connector(source_type: str) -> Any:
|
||||||
|
"""Return the registered connector for ``source_type`` or raise KeyError."""
|
||||||
|
try:
|
||||||
|
return _CONNECTORS[source_type]
|
||||||
|
except KeyError as exc:
|
||||||
|
raise KeyError(f"No connector registered for source_type {source_type!r}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_for_tests() -> None:
|
||||||
|
"""Clear the registry — for use in pytest fixtures only."""
|
||||||
|
_CONNECTORS.clear()
|
||||||
270
app/scouts/engine.py
Normal file
270
app/scouts/engine.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
|
||||||
|
|
||||||
|
Triage flow per scout:
|
||||||
|
1. Resolve scout config from the DB.
|
||||||
|
2. Skip if device hasn't connected within ``device_inactivity_pause_days``.
|
||||||
|
3. Ask the connector to ``list_new`` — fresh items since last poll.
|
||||||
|
4. For each item:
|
||||||
|
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
|
||||||
|
- fetch the full content via the connector (transient, never persisted)
|
||||||
|
- run the triage LLM call → relevant | spam
|
||||||
|
- spam + auto_trash_spam → connector.archive
|
||||||
|
- relevant → INSERT scout_triage_queue row
|
||||||
|
5. Update scout.last_run_at.
|
||||||
|
|
||||||
|
Delivery flow on Electron WS reconnect:
|
||||||
|
- drain ``status='queued'`` rows for the user
|
||||||
|
- fetch metadata-only for each (subject + snippet)
|
||||||
|
- send a ``scout_proposal`` frame
|
||||||
|
- flip status to ``delivered`` on ack
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import CloudScoutConfig, ScoutTriageQueue
|
||||||
|
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
|
||||||
|
from app.scouts.connectors.registry import get_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
QUEUE_TTL_DAYS = 30
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutEngine:
|
||||||
|
def __init__(self, session_factory=None) -> None:
|
||||||
|
self._session_factory = session_factory or async_session
|
||||||
|
|
||||||
|
async def trigger_scout(self, scout_id: uuid.UUID) -> None:
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
scout = await session.get(CloudScoutConfig, str(scout_id))
|
||||||
|
if scout is None:
|
||||||
|
logger.warning("trigger_scout: no such scout id=%s", scout_id)
|
||||||
|
return
|
||||||
|
if not scout.enabled:
|
||||||
|
return
|
||||||
|
# Device-inactivity pause check is a simple heuristic on last_run_at —
|
||||||
|
# the device-online signal lives in the DeviceConnectionManager and is
|
||||||
|
# consulted at delivery time. For triage, we only check that the
|
||||||
|
# configured pause threshold isn't suppressing the run.
|
||||||
|
connector = get_connector(scout.provider)
|
||||||
|
try:
|
||||||
|
refs = await connector.list_new(scout)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout %s: list_new failed", scout.id)
|
||||||
|
return
|
||||||
|
|
||||||
|
for ref in refs:
|
||||||
|
await self._process_item(session, scout, connector, ref)
|
||||||
|
|
||||||
|
scout.last_run_at = datetime.now(tz=timezone.utc)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def _process_item(
|
||||||
|
self,
|
||||||
|
session,
|
||||||
|
scout: CloudScoutConfig,
|
||||||
|
connector,
|
||||||
|
ref: ItemRef,
|
||||||
|
) -> None:
|
||||||
|
# Idempotency check
|
||||||
|
existing = await session.execute(
|
||||||
|
select(ScoutTriageQueue.id).where(
|
||||||
|
ScoutTriageQueue.scout_id == scout.id,
|
||||||
|
ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if existing.first() is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = await connector.fetch_content(scout, ref)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
verdict = await self._triage_llm(scout, content)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
|
||||||
|
return
|
||||||
|
|
||||||
|
if verdict.verdict == "spam":
|
||||||
|
if scout.auto_trash_spam:
|
||||||
|
try:
|
||||||
|
await connector.archive(scout, ref)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
|
||||||
|
return
|
||||||
|
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
row = ScoutTriageQueue(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=scout.user_id,
|
||||||
|
scout_id=scout.id,
|
||||||
|
source_type=connector.source_type,
|
||||||
|
source_msg_ref=ref.source_msg_ref,
|
||||||
|
triage_verdict=verdict.verdict,
|
||||||
|
triage_reason=verdict.reason,
|
||||||
|
status="queued",
|
||||||
|
triaged_at=now,
|
||||||
|
expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
|
try:
|
||||||
|
# Use a savepoint so an IntegrityError on race doesn't poison the
|
||||||
|
# outer session — works on both PostgreSQL (SAVEPOINT) and SQLite.
|
||||||
|
async with session.begin_nested():
|
||||||
|
await session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
# Race: another worker inserted between our SELECT and INSERT.
|
||||||
|
# The unique constraint did its job; safe to ignore.
|
||||||
|
logger.debug(
|
||||||
|
"scout %s: idempotent skip for %s (race on unique constraint)",
|
||||||
|
scout.id,
|
||||||
|
ref.source_msg_ref,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
|
||||||
|
"""Drain status='queued' rows for user, send scout_proposal WS frames, flip to 'delivered'."""
|
||||||
|
from app.scouts.connectors.base import ItemRef # noqa: PLC0415
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
rows = (await session.execute(
|
||||||
|
select(ScoutTriageQueue).where(
|
||||||
|
ScoutTriageQueue.user_id == str(user_id),
|
||||||
|
ScoutTriageQueue.status == "queued",
|
||||||
|
)
|
||||||
|
)).scalars().all()
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
connector = get_connector(row.source_type)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning("deliver_pending: no connector for %s", row.source_type)
|
||||||
|
continue
|
||||||
|
scout = await session.get(CloudScoutConfig, row.scout_id)
|
||||||
|
if scout is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
meta = await connector.fetch_metadata(scout, ItemRef(source_msg_ref=row.source_msg_ref))
|
||||||
|
except Exception:
|
||||||
|
logger.exception("deliver_pending: fetch_metadata failed")
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"type": "scout_proposal",
|
||||||
|
"proposal": {
|
||||||
|
"id": row.id,
|
||||||
|
"scout_id": row.scout_id,
|
||||||
|
"source_type": row.source_type,
|
||||||
|
"source_msg_ref": row.source_msg_ref,
|
||||||
|
"raw_subject": meta.subject,
|
||||||
|
"raw_snippet": meta.snippet,
|
||||||
|
"category": "unprocessed",
|
||||||
|
"payload": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
await ws.send_json(payload)
|
||||||
|
row.status = "delivered"
|
||||||
|
row.delivered_at = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def ack_proposal(self, proposal_id: str) -> None:
|
||||||
|
"""Flip a delivered proposal to acked. Idempotent — no-op if already acked."""
|
||||||
|
async with self._session_factory() as session:
|
||||||
|
row = await session.get(ScoutTriageQueue, proposal_id)
|
||||||
|
if row is None:
|
||||||
|
return
|
||||||
|
row.status = "acked"
|
||||||
|
row.acked_at = datetime.now(tz=timezone.utc)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
|
||||||
|
"""Call the scout-triage-system Langfuse prompt to classify an item as relevant or spam.
|
||||||
|
|
||||||
|
Uses gpt-4o-mini with JSON mode. Wraps the LLM call in a Langfuse generation
|
||||||
|
observation when Langfuse is configured.
|
||||||
|
"""
|
||||||
|
import json # noqa: PLC0415
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||||
|
|
||||||
|
_TRIAGE_FALLBACK = (
|
||||||
|
"You are a triage classifier for an executive-assistant scout that watches a "
|
||||||
|
"{source_type} feed.\n"
|
||||||
|
'The scout\'s purpose is: "{scout_purpose}".\n\n'
|
||||||
|
"Given one item, decide whether it is RELEVANT (worth surfacing to the user as a "
|
||||||
|
"potential task / event / note / project) or SPAM (advertising, mass marketing, "
|
||||||
|
"phishing, bulk notifications with no actionable content).\n\n"
|
||||||
|
"Item:\n"
|
||||||
|
" - Subject: {item_subject}\n"
|
||||||
|
" - From: {item_sender}\n"
|
||||||
|
" - Body (truncated): {item_body_truncated_2k}\n\n"
|
||||||
|
'Return JSON only, matching this schema:\n'
|
||||||
|
' {{"verdict": "relevant" | "spam", "reason": <short string>, "confidence": <0..1>}}\n\n'
|
||||||
|
"Be conservative on \"spam\" — if a message could plausibly be a personal/work "
|
||||||
|
"email, mark it relevant."
|
||||||
|
)
|
||||||
|
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("scout-triage-system", _TRIAGE_FALLBACK)
|
||||||
|
|
||||||
|
body_trunc = (content.body_text or "")[:2000]
|
||||||
|
variables = dict(
|
||||||
|
source_type=scout.provider,
|
||||||
|
scout_purpose=scout.prompt_template or "",
|
||||||
|
item_subject=content.metadata.subject or "",
|
||||||
|
item_sender=content.metadata.sender or "",
|
||||||
|
item_body_truncated_2k=body_trunc,
|
||||||
|
)
|
||||||
|
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
system_text = prompt_obj.compile(**variables)
|
||||||
|
if isinstance(system_text, list):
|
||||||
|
system_text = "\n".join(
|
||||||
|
m.get("content", "") for m in system_text if isinstance(m, dict)
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("scout triage: compile failed: %s", exc)
|
||||||
|
system_text = template.replace("{{source_type}}", variables["source_type"]) \
|
||||||
|
.replace("{{scout_purpose}}", variables["scout_purpose"]) \
|
||||||
|
.replace("{{item_subject}}", variables["item_subject"]) \
|
||||||
|
.replace("{{item_sender}}", variables["item_sender"]) \
|
||||||
|
.replace("{{item_body_truncated_2k}}", variables["item_body_truncated_2k"])
|
||||||
|
else:
|
||||||
|
system_text = template.format(**variables)
|
||||||
|
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||||
|
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=system_text),
|
||||||
|
HumanMessage(content="Classify this item."),
|
||||||
|
]
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="scout-triage",
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm_json.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm_json.ainvoke(messages)
|
||||||
|
|
||||||
|
data = json.loads(response.content)
|
||||||
|
return TriageVerdict(**data)
|
||||||
@@ -39,3 +39,5 @@ lxml>=5.0.0
|
|||||||
PyYAML>=6.0.0
|
PyYAML>=6.0.0
|
||||||
apscheduler>=3.10.0
|
apscheduler>=3.10.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
pypdf>=4.0
|
||||||
|
python-docx>=1.1
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from jose import jwt
|
|||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Convenience aliases and per-tier user fixtures ────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db(db_session: AsyncSession) -> AsyncSession:
|
||||||
|
"""Alias for db_session — used by folder quota tests."""
|
||||||
|
return db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user_free(db_session: AsyncSession):
|
||||||
|
"""Return the seeded free-tier User row."""
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.id == TEST_USER_IDS["free"])
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user_power(db_session: AsyncSession):
|
||||||
|
"""Return the seeded power-tier User row."""
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.id == TEST_USER_IDS["power"])
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_headers_free() -> dict[str, str]:
|
||||||
|
"""Authorization header for the seeded free-tier user."""
|
||||||
|
return auth_header("free")
|
||||||
|
|
||||||
|
|
||||||
# ── CLI options ───────────────────────────────────────────────────────
|
# ── CLI options ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
from app.core.scout_runner import (
|
||||||
_format_metadata,
|
_format_metadata,
|
||||||
_format_projects,
|
_format_projects,
|
||||||
_get_extraction_rules,
|
_get_extraction_rules,
|
||||||
@@ -44,7 +44,7 @@ from app.core.agent_runner import (
|
|||||||
)
|
)
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.core.langfuse_client import get_langfuse
|
from app.core.langfuse_client import get_langfuse
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
from app.models import ScoutRunLog, LocalScoutConfig
|
||||||
from tests.conftest import TEST_USER_IDS
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
# ── Constants ─────────────────────────────────────────────────────────────
|
||||||
@@ -127,8 +127,8 @@ def _make_config(
|
|||||||
agent_config: dict | None = None,
|
agent_config: dict | None = None,
|
||||||
directory: str = "/emails",
|
directory: str = "/emails",
|
||||||
device_id: str = "dev-001",
|
device_id: str = "dev-001",
|
||||||
) -> LocalAgentConfig:
|
) -> LocalScoutConfig:
|
||||||
return LocalAgentConfig(
|
return LocalScoutConfig(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=_USER_ID,
|
user_id=_USER_ID,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
@@ -136,7 +136,7 @@ def _make_config(
|
|||||||
directory_paths=[directory],
|
directory_paths=[directory],
|
||||||
data_types=["tasks", "notes", "timelines"],
|
data_types=["tasks", "notes", "timelines"],
|
||||||
prompt_template="",
|
prompt_template="",
|
||||||
agent_config=agent_config or _AGENT_CONFIG,
|
scout_config=agent_config or _AGENT_CONFIG,
|
||||||
file_extensions=[".html", ".eml"],
|
file_extensions=[".html", ".eml"],
|
||||||
schedule_cron="0 */6 * * *",
|
schedule_cron="0 */6 * * *",
|
||||||
enabled=True,
|
enabled=True,
|
||||||
@@ -144,11 +144,11 @@ def _make_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_run_log(agent_id: str) -> AgentRunLog:
|
def _make_run_log(agent_id: str) -> ScoutRunLog:
|
||||||
return AgentRunLog(
|
return ScoutRunLog(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
agent_id=agent_id,
|
scout_id=agent_id,
|
||||||
agent_type="local",
|
scout_type="local",
|
||||||
user_id=_USER_ID,
|
user_id=_USER_ID,
|
||||||
status="running",
|
status="running",
|
||||||
started_at=datetime.now(timezone.utc),
|
started_at=datetime.now(timezone.utc),
|
||||||
@@ -271,7 +271,7 @@ async def test_2_9_device_offline():
|
|||||||
run_log = _make_run_log(config.id)
|
run_log = _make_run_log(config.id)
|
||||||
mgr = _make_manager(online=False)
|
mgr = _make_manager(online=False)
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
with patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
_, kwargs = mock_fin.call_args
|
||||||
@@ -295,8 +295,8 @@ async def test_2_10_empty_file():
|
|||||||
projects=[_PROJECTS["alpha"]],
|
projects=[_PROJECTS["alpha"]],
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
_, kwargs = mock_fin.call_args
|
||||||
@@ -326,9 +326,9 @@ async def test_2_8_items_created_count():
|
|||||||
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
||||||
return "Done."
|
return "Done."
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
|
||||||
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
patch("app.core.scout_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
_, kwargs = mock_fin.call_args
|
||||||
@@ -377,8 +377,8 @@ async def test_eval_runner(runner_case, pytestconfig):
|
|||||||
) if lf else nullcontext()
|
) if lf else nullcontext()
|
||||||
|
|
||||||
with obs_ctx as obs:
|
with obs_ctx as obs:
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
_, kwargs = mock_fin.call_args
|
||||||
|
|||||||
52
tests/test_contextual_scope.py
Normal file
52
tests/test_contextual_scope.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import pytest
|
||||||
|
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_project_scope():
|
||||||
|
scope = ContextualScope(
|
||||||
|
page="project",
|
||||||
|
entity_type="project",
|
||||||
|
entity_id="p1",
|
||||||
|
entity_name="Acme Q3 launch",
|
||||||
|
counts={"tasks": 12, "notes": 4, "milestones": 3},
|
||||||
|
)
|
||||||
|
block = render_scope_block(scope)
|
||||||
|
assert "Acme Q3 launch" in block
|
||||||
|
assert "12 tasks" in block
|
||||||
|
assert "4 notes" in block
|
||||||
|
assert "3 milestones" in block
|
||||||
|
assert "p1" not in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_list_scope_no_entity():
|
||||||
|
scope = ContextualScope(page="tasks", entity_type=None)
|
||||||
|
block = render_scope_block(scope)
|
||||||
|
assert "tasks" in block.lower()
|
||||||
|
assert "None" not in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_note_scope_includes_char_count():
|
||||||
|
scope = ContextualScope(
|
||||||
|
page="note",
|
||||||
|
entity_type="note",
|
||||||
|
entity_id="n1",
|
||||||
|
entity_name="Meeting 14 May",
|
||||||
|
project_id="p1",
|
||||||
|
char_count=4280,
|
||||||
|
)
|
||||||
|
block = render_scope_block(scope)
|
||||||
|
assert "Meeting 14 May" in block
|
||||||
|
assert "4280" in block or "4,280" in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_parses_camelcase_payload_from_renderer():
|
||||||
|
payload = {
|
||||||
|
"page": "project",
|
||||||
|
"entityType": "project",
|
||||||
|
"entityId": "p1",
|
||||||
|
"entityName": "Acme",
|
||||||
|
"counts": {"tasks": 5, "notes": 1, "milestones": 2},
|
||||||
|
}
|
||||||
|
scope = ContextualScope.model_validate(payload)
|
||||||
|
assert scope.entity_id == "p1"
|
||||||
|
assert scope.entity_name == "Acme"
|
||||||
44
tests/test_contextual_ws.py
Normal file
44
tests/test_contextual_ws.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
"""Tests for contextual WS frame handlers.
|
||||||
|
|
||||||
|
These tests only exercise the new handler functions in device_ws.py and do
|
||||||
|
not depend on litellm or the full deep_agent import chain. They monkeypatch
|
||||||
|
run_contextual_stream so no LLM call is made.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch):
|
||||||
|
"""_handle_contextual_scope_update must:
|
||||||
|
- call append_system_message on the session buffer
|
||||||
|
- send a contextual_scope_ack back on the socket
|
||||||
|
- make no LLM call
|
||||||
|
"""
|
||||||
|
from app.api.routes import device_ws
|
||||||
|
|
||||||
|
ws = AsyncMock()
|
||||||
|
buffer = MagicMock()
|
||||||
|
buffer.append_system_message = MagicMock()
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"type": "contextual_scope_update",
|
||||||
|
"session_id": "s1",
|
||||||
|
"scope": {
|
||||||
|
"page": "project",
|
||||||
|
"entityType": "project",
|
||||||
|
"entityId": "p1",
|
||||||
|
"entityName": "Acme",
|
||||||
|
"counts": {"tasks": 1, "notes": 0, "milestones": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer)
|
||||||
|
await device_ws._handle_contextual_scope_update(ws, "user1", payload)
|
||||||
|
|
||||||
|
ws.send_text.assert_awaited_once()
|
||||||
|
import json
|
||||||
|
sent = json.loads(ws.send_text.await_args.args[0])
|
||||||
|
assert sent["type"] == "contextual_scope_ack"
|
||||||
|
assert sent["session_id"] == "s1"
|
||||||
|
buffer.append_system_message.assert_called_once()
|
||||||
@@ -12,11 +12,8 @@ from langchain_core.messages import AIMessage, ToolMessage
|
|||||||
from app.core.deep_agent import (
|
from app.core.deep_agent import (
|
||||||
_build_system_prompt,
|
_build_system_prompt,
|
||||||
_datetime_context_injection,
|
_datetime_context_injection,
|
||||||
_infer_floating_domain,
|
|
||||||
_normalize_tagged_list_lines,
|
_normalize_tagged_list_lines,
|
||||||
_request_context_block,
|
_request_context_block,
|
||||||
run_floating,
|
|
||||||
run_floating_stream,
|
|
||||||
run_home,
|
run_home,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -75,57 +72,6 @@ async def test_run_home_uses_mocked_tool_result():
|
|||||||
assert "Mock Task" in out
|
assert "Mock Task" in out
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"show me timeline updates",
|
|
||||||
{"scope": {"type": "timeline", "id": "tl-1"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert events[0] == (
|
|
||||||
"floating_domain",
|
|
||||||
{"type": "timeline", "id": "tl-1", "section": None},
|
|
||||||
)
|
|
||||||
# _run_single_agent_stream uses ainvoke (not astream); the final token is
|
|
||||||
# the second LLM response which echoes the tool result.
|
|
||||||
token_events = [e for e in events if e[0] == "token"]
|
|
||||||
assert token_events, "Expected at least one token event"
|
|
||||||
combined = "".join(str(e[1]) for e in token_events)
|
|
||||||
assert "Mock Task" in combined
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
|
||||||
class _ClassifierOnlyLLM:
|
|
||||||
async def ainvoke(self, _messages):
|
|
||||||
return AIMessage(
|
|
||||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
|
|
||||||
domain = await _infer_floating_domain(
|
|
||||||
"Quali sono i miei task per il progetto X",
|
|
||||||
{
|
|
||||||
"scope": {"type": "timeline"},
|
|
||||||
"resolved_project_id": "213213-312321-312312-421321",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert domain == {
|
|
||||||
"type": "project",
|
|
||||||
"id": "213213-312321-312312-421321",
|
|
||||||
"section": "task",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
||||||
raw = (
|
raw = (
|
||||||
"Certo!\n\n"
|
"Certo!\n\n"
|
||||||
@@ -162,139 +108,6 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
|
|||||||
assert "<timeline>[tl-future]</timeline>" not in out
|
assert "<timeline>[tl-future]</timeline>" not in out
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_run_single_agent(**_kwargs):
|
|
||||||
return (
|
|
||||||
"Hai 1 task:\\n"
|
|
||||||
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
|
||||||
):
|
|
||||||
text, _domain = await run_floating(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "<task>" not in text
|
|
||||||
assert "</task>" not in text
|
|
||||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_stream(**_kwargs):
|
|
||||||
yield "token", "Hai 1 task:\\n"
|
|
||||||
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
|
||||||
combined = "".join(token_events)
|
|
||||||
assert "<task>" not in combined
|
|
||||||
assert "</task>" not in combined
|
|
||||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
|
||||||
class _NoChunkLLM:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.calls = 0
|
|
||||||
|
|
||||||
def bind_tools(self, _tools):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def ainvoke(self, _messages):
|
|
||||||
self.calls += 1
|
|
||||||
if self.calls == 1:
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call-1",
|
|
||||||
"name": "list_tasks",
|
|
||||||
"args": {},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return AIMessage(content="No notes found.")
|
|
||||||
|
|
||||||
async def astream(self, _messages):
|
|
||||||
if False:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali sono le note?",
|
|
||||||
{"scope": {"type": "note"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert events[0][0] == "floating_domain"
|
|
||||||
assert ("token", "No notes found.") in events
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_run_single_agent(**_kwargs):
|
|
||||||
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
|
||||||
):
|
|
||||||
text, _domain = await run_floating(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert text == "No results found."
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_stream(**_kwargs):
|
|
||||||
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert ("token", "No results found.") in events
|
|
||||||
|
|
||||||
|
|
||||||
# ── _datetime_context_injection ────────────────────────────────────────────────
|
# ── _datetime_context_injection ────────────────────────────────────────────────
|
||||||
|
|
||||||
def _fp(tz: str, now_iso: str) -> dict:
|
def _fp(tz: str, now_iso: str) -> dict:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import pytest
|
|||||||
from app.core.device_manager import DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import AgentRunLog
|
from app.models import ScoutRunLog
|
||||||
from tests.conftest import TEST_USER_IDS, make_jwt
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -33,9 +33,9 @@ _FREE_UID = TEST_USER_IDS["free"]
|
|||||||
_PRO_UID = TEST_USER_IDS["pro"]
|
_PRO_UID = TEST_USER_IDS["pro"]
|
||||||
|
|
||||||
|
|
||||||
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str:
|
def _device_hello(device_id: str = "dev-001", scout_ids: list[str] | None = None) -> str:
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []}
|
{"type": "device_hello", "device_id": device_id, "scout_ids": scout_ids or []}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -262,10 +262,10 @@ async def test_mark_runs_disconnected_updates_db(db_session):
|
|||||||
|
|
||||||
user_id = TEST_USER_IDS["free"]
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = ScoutRunLog(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
agent_id=str(uuid.uuid4()),
|
scout_id=str(uuid.uuid4()),
|
||||||
agent_type="local",
|
scout_type="local",
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
status="running",
|
status="running",
|
||||||
started_at=datetime.now(timezone.utc),
|
started_at=datetime.now(timezone.utc),
|
||||||
@@ -280,7 +280,7 @@ async def test_mark_runs_disconnected_updates_db(db_session):
|
|||||||
# Verify through the same session factory.
|
# Verify through the same session factory.
|
||||||
async with _TestSessionLocal() as s:
|
async with _TestSessionLocal() as s:
|
||||||
result = await s.execute(
|
result = await s.execute(
|
||||||
select(AgentRunLog).where(AgentRunLog.id == run_log.id)
|
select(ScoutRunLog).where(ScoutRunLog.id == run_log.id)
|
||||||
)
|
)
|
||||||
updated = result.scalar_one_or_none()
|
updated = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|||||||
139
tests/test_folder_agent_tool.py
Normal file
139
tests/test_folder_agent_tool.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.folder_agent import (
|
||||||
|
read_project_folder_file,
|
||||||
|
search_project_folder_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_happy_path():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
|
||||||
|
assert "file body" in out
|
||||||
|
assert "kind=text" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_traversal_rejected():
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
|
||||||
|
assert out == "Access denied"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_absolute_path_rejected():
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
|
||||||
|
assert out == "Access denied"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_missing_file():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
|
||||||
|
assert "not found" in out.lower()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pagination_signals_more_available():
|
||||||
|
# Electron returned the first slice, totalSize larger than slice length.
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "big.txt",
|
||||||
|
"offset": 0,
|
||||||
|
"length": 11,
|
||||||
|
})
|
||||||
|
assert "first chunk" in out
|
||||||
|
assert "More content available" in out
|
||||||
|
assert "offset=11" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pdf_extracted_then_sliced(monkeypatch):
|
||||||
|
from app.agents import folder_agent
|
||||||
|
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "doc.pdf",
|
||||||
|
"offset": 0,
|
||||||
|
"length": 8,
|
||||||
|
})
|
||||||
|
assert "kind=pdf" in out
|
||||||
|
assert "ABC ABC " in out
|
||||||
|
assert "More content available" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_image_returns_placeholder():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
|
||||||
|
assert "image" in out.lower()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_finds_match_with_context():
|
||||||
|
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
|
||||||
|
):
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "log.txt",
|
||||||
|
"query": "needle",
|
||||||
|
"context_lines": 1,
|
||||||
|
})
|
||||||
|
assert "needle" in out
|
||||||
|
assert "matches=1" in out
|
||||||
|
# Context lines included
|
||||||
|
assert "beta" in out
|
||||||
|
assert "gamma" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_no_match():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
|
||||||
|
):
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "x.txt",
|
||||||
|
"query": "zzz",
|
||||||
|
})
|
||||||
|
assert "No matches" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_rejects_traversal():
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "../etc/passwd",
|
||||||
|
"query": "root",
|
||||||
|
})
|
||||||
|
assert out == "Access denied"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_image_rejected():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
|
||||||
|
):
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "logo.png",
|
||||||
|
"query": "anything",
|
||||||
|
})
|
||||||
|
assert "Cannot search" in out
|
||||||
83
tests/test_folder_indexer.py
Normal file
83
tests/test_folder_indexer.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Folder indexer LLM helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_text_returns_summary_and_tokens():
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.content = "Kickoff notes covering scope and deadlines."
|
||||||
|
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||||
|
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
|
||||||
|
assert isinstance(result, IndexResult)
|
||||||
|
assert result.summary == "Kickoff notes covering scope and deadlines."
|
||||||
|
assert result.tokens_used == 338
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_text_truncates_summary_at_500_chars():
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.content = "x" * 1000
|
||||||
|
mock_resp.usage_metadata = {"total_tokens": 100}
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||||
|
result = await summarize_text(content="x", ext=".md", name="x.md")
|
||||||
|
assert len(result.summary) <= 500
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_image_uses_vision_content_blocks():
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.content = "Final logo on white background."
|
||||||
|
mock_resp.usage_metadata = {"total_tokens": 500}
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def fake_llm_vision(messages):
|
||||||
|
captured["messages"] = messages
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
|
||||||
|
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
|
||||||
|
|
||||||
|
assert "Final logo" in result.summary
|
||||||
|
assert result.tokens_used == 500
|
||||||
|
# last message contains an image content block
|
||||||
|
last = captured["messages"][-1]
|
||||||
|
assert any(
|
||||||
|
isinstance(p, dict) and p.get("type") == "image_url"
|
||||||
|
for p in (last.content if isinstance(last.content, list) else [])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
|
||||||
|
# pypdf.PdfReader returns text from pages
|
||||||
|
from app.core import folder_indexer
|
||||||
|
class FakePage:
|
||||||
|
def extract_text(self): return "PDF page content with project info."
|
||||||
|
class FakeReader:
|
||||||
|
pages = [FakePage(), FakePage()]
|
||||||
|
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
|
||||||
|
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
|
||||||
|
async def fake_llm(messages): return mock_resp
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||||
|
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
|
||||||
|
assert "Project info" in result.summary
|
||||||
|
assert result.tokens_used == 50
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
|
||||||
|
from app.core import folder_indexer
|
||||||
|
class FakePara:
|
||||||
|
def __init__(self, t): self.text = t
|
||||||
|
class FakeDoc:
|
||||||
|
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
|
||||||
|
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
|
||||||
|
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
|
||||||
|
async def fake_llm(messages): return mock_resp
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||||
|
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
|
||||||
|
assert result.summary == "Heading and body."
|
||||||
94
tests/test_folder_quota.py
Normal file
94
tests/test_folder_quota.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Folder quota helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.billing.quota import (
|
||||||
|
check_folder_quota,
|
||||||
|
add_token_usage,
|
||||||
|
QuotaExceeded,
|
||||||
|
)
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
|
||||||
|
with pytest.raises(QuotaExceeded) as exc:
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
|
||||||
|
)
|
||||||
|
assert exc.value.reason == "max_files"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
|
||||||
|
# No raise
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
db.add(MonthlyTokenUsage(
|
||||||
|
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
|
||||||
|
))
|
||||||
|
await db.commit()
|
||||||
|
with pytest.raises(QuotaExceeded) as exc:
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
|
||||||
|
)
|
||||||
|
assert exc.value.reason == "monthly_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_power_unlimited(db, test_user_power):
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_token_usage_atomic_increment(db, test_user_free):
|
||||||
|
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
|
||||||
|
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
row = (await db.execute(
|
||||||
|
select(MonthlyTokenUsage).where(
|
||||||
|
MonthlyTokenUsage.user_id == test_user_free.id,
|
||||||
|
MonthlyTokenUsage.year_month == ym,
|
||||||
|
MonthlyTokenUsage.feature == "folder_index",
|
||||||
|
)
|
||||||
|
)).scalar_one()
|
||||||
|
assert row.tokens_used == 4000
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
|
||||||
|
result = await add_token_usage(
|
||||||
|
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
|
||||||
|
)
|
||||||
|
assert result.exhausted is True
|
||||||
|
assert result.tokens_used == 150_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_check_endpoint_rejects(client, auth_headers_free):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/billing/quota/check",
|
||||||
|
json={"feature": "folder_index", "estimated_files": 500},
|
||||||
|
headers=auth_headers_free,
|
||||||
|
)
|
||||||
|
assert res.status_code == 402
|
||||||
|
body = res.json()
|
||||||
|
assert body["detail"]["reason"] == "max_files"
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_check_endpoint_passes(client, auth_headers_free):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/billing/quota/check",
|
||||||
|
json={"feature": "folder_index", "estimated_files": 50},
|
||||||
|
headers=auth_headers_free,
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {"ok": True}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Tests for Local Agent V2 journey setup (Step 4).
|
"""Tests for Local Agent V2 journey setup (Step 4).
|
||||||
|
|
||||||
Covers the chatbot journey that produces a structured AgentConfig JSON
|
Covers the chatbot journey that produces a structured ScoutConfig JSON
|
||||||
instead of a freeform prompt_template string.
|
instead of a freeform prompt_template string.
|
||||||
|
|
||||||
Unit tests (no LLM)
|
Unit tests (no LLM)
|
||||||
@@ -16,7 +16,7 @@ Eval test (real LLM + Langfuse scoring)
|
|||||||
----------------------------------------
|
----------------------------------------
|
||||||
4.1 Journey start explores directory → first reply contains a question
|
4.1 Journey start explores directory → first reply contains a question
|
||||||
|
|
||||||
Cases 4.2–4.5 (multi-turn conversations producing a full AgentConfig) are
|
Cases 4.2–4.5 (multi-turn conversations producing a full ScoutConfig) are
|
||||||
non-deterministic and tested manually — results tracked in Langfuse.
|
non-deterministic and tested manually — results tracked in Langfuse.
|
||||||
|
|
||||||
Run:
|
Run:
|
||||||
@@ -37,7 +37,7 @@ from unittest.mock import patch
|
|||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from app.api.routes.agent_setup import (
|
from app.api.routes.scout_setup import (
|
||||||
_CONFIG_END,
|
_CONFIG_END,
|
||||||
_CONFIG_START,
|
_CONFIG_START,
|
||||||
_MAX_TURNS,
|
_MAX_TURNS,
|
||||||
@@ -48,7 +48,7 @@ from app.api.routes.agent_setup import (
|
|||||||
)
|
)
|
||||||
from app.core.langfuse_client import get_langfuse
|
from app.core.langfuse_client import get_langfuse
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.schemas import AgentConfig
|
from app.schemas import ScoutConfig
|
||||||
from tests.conftest import TEST_USER_IDS
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
# ── Constants ─────────────────────────────────────────────────────────────
|
||||||
@@ -179,7 +179,7 @@ def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
|
|||||||
|
|
||||||
def test_4_6a_extract_valid_json():
|
def test_4_6a_extract_valid_json():
|
||||||
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
|
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
|
||||||
config = AgentConfig(
|
config = ScoutConfig(
|
||||||
content_types=[],
|
content_types=[],
|
||||||
global_rules=["No project = no entity"],
|
global_rules=["No project = no entity"],
|
||||||
data_types=["tasks"],
|
data_types=["tasks"],
|
||||||
@@ -187,7 +187,7 @@ def test_4_6a_extract_valid_json():
|
|||||||
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
|
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
|
||||||
result = _extract_agent_config(text)
|
result = _extract_agent_config(text)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
parsed = AgentConfig.model_validate_json(result)
|
parsed = ScoutConfig.model_validate_json(result)
|
||||||
assert parsed.global_rules == ["No project = no entity"]
|
assert parsed.global_rules == ["No project = no entity"]
|
||||||
|
|
||||||
|
|
||||||
@@ -230,7 +230,7 @@ async def test_4_6f_nudge_uses_new_markers():
|
|||||||
# Return plain text — no markers — to trigger the nudge path.
|
# Return plain text — no markers — to trigger the nudge path.
|
||||||
return "I still need more information from you."
|
return "I still need more information from you."
|
||||||
|
|
||||||
from app.api.routes.agent_setup import JourneySession
|
from app.api.routes.scout_setup import JourneySession
|
||||||
|
|
||||||
fake_session = JourneySession(
|
fake_session = JourneySession(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -248,7 +248,7 @@ async def test_4_6f_nudge_uses_new_markers():
|
|||||||
_sessions[session_id] = fake_session
|
_sessions[session_id] = fake_session
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm):
|
with patch("app.api.routes.scout_setup._call_llm_with_tools", side_effect=_mock_llm):
|
||||||
await handle_journey_message(_USER_ID, {
|
await handle_journey_message(_USER_ID, {
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": "one more message to trigger nudge",
|
"message": "one more message to trigger nudge",
|
||||||
|
|||||||
69
tests/test_manifest_injection.py
Normal file
69
tests/test_manifest_injection.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_folder_manifest_basic():
|
||||||
|
manifest = {
|
||||||
|
"folderPath": "D:\\Acme",
|
||||||
|
"lastScannedAt": "2h ago",
|
||||||
|
"files": [
|
||||||
|
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
|
||||||
|
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
out = format_folder_manifest(manifest)
|
||||||
|
assert "<linked_folder>" in out
|
||||||
|
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
|
||||||
|
assert "[text]" in out
|
||||||
|
assert "[image]" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_folder_manifest_truncates_past_budget():
|
||||||
|
files = [
|
||||||
|
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
|
||||||
|
for i in range(2000)
|
||||||
|
]
|
||||||
|
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
|
||||||
|
assert "more files omitted" in out
|
||||||
|
# Rough token check
|
||||||
|
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_folder_manifest_null_returns_empty():
|
||||||
|
assert format_folder_manifest(None) == ""
|
||||||
|
assert format_folder_manifest({"files": []}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_brief_multi_project_manifest_top_5_per_project():
|
||||||
|
fake_response = [
|
||||||
|
{
|
||||||
|
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
|
||||||
|
"lastScannedAt": "now",
|
||||||
|
"files": [
|
||||||
|
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
|
||||||
|
for i in range(10)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
|
||||||
|
"lastScannedAt": "now",
|
||||||
|
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
with patch(
|
||||||
|
"app.core.deep_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"projects": fake_response}),
|
||||||
|
):
|
||||||
|
from app.core.deep_agent import build_brief_multi_project_manifest
|
||||||
|
out = await build_brief_multi_project_manifest()
|
||||||
|
# Project 1 has 10 files, only top 5 by mtimeMs should appear
|
||||||
|
assert out.count("[p1]") <= 5
|
||||||
|
# Project 2 has 1 file, must appear
|
||||||
|
assert "[p2]" in out or "Beta" in out
|
||||||
@@ -322,7 +322,7 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
):
|
):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-mem", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-mem", "scout_ids": []
|
||||||
}))
|
}))
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "home_request",
|
"type": "home_request",
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
|
|
||||||
|
|
||||||
async def _stream(*events: tuple[str, object]):
|
async def _stream(*events: tuple[str, object]):
|
||||||
@@ -36,29 +36,6 @@ async def test_stream_formatter_text_stream() -> None:
|
|||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_stream_formatter_floating_domain_first() -> None:
|
|
||||||
formatter = StreamFormatter(request_id="req-2")
|
|
||||||
frames = await _collect(
|
|
||||||
formatter,
|
|
||||||
_stream(
|
|
||||||
(
|
|
||||||
"floating_domain",
|
|
||||||
{"type": "node", "id": "n-1", "section": None},
|
|
||||||
),
|
|
||||||
("token", "Summary"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
|
||||||
assert frames[0].domain.type == "node"
|
|
||||||
assert frames[0].domain.id == "n-1"
|
|
||||||
assert isinstance(frames[1], WsStreamStart)
|
|
||||||
assert isinstance(frames[2], WsStreamText)
|
|
||||||
assert frames[2].chunk == "Summary"
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_ignores_unknown_events() -> None:
|
async def test_stream_formatter_ignores_unknown_events() -> None:
|
||||||
formatter = StreamFormatter(request_id="req-3")
|
formatter = StreamFormatter(request_id="req-3")
|
||||||
|
|||||||
85
tests/test_run_contextual.py
Normal file
85
tests/test_run_contextual.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Tests for run_contextual_stream.
|
||||||
|
|
||||||
|
These tests monkeypatch _run_single_agent_stream (the actual internal runner)
|
||||||
|
rather than the plan's fictional _run_agent_loop, matching the real
|
||||||
|
deep_agent.py architecture.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from app.schemas.contextual import ContextualScope
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_contextual_stream_includes_scope_block(monkeypatch):
|
||||||
|
"""run_contextual_stream must inject the scope block into the system prompt
|
||||||
|
and include get_page_details in the tool list while excluding note-edit tools."""
|
||||||
|
import app.core.deep_agent as deep_agent
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def fake_stream(
|
||||||
|
*,
|
||||||
|
user_id,
|
||||||
|
system_prompt,
|
||||||
|
message,
|
||||||
|
context,
|
||||||
|
agent_name="agent",
|
||||||
|
tools=None,
|
||||||
|
conversation_history=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
captured["sys"] = system_prompt
|
||||||
|
captured["tool_names"] = [getattr(t, "name", str(t)) for t in (tools or [])]
|
||||||
|
captured["agent_name"] = agent_name
|
||||||
|
# Async generator that yields nothing — still satisfies the protocol.
|
||||||
|
if False:
|
||||||
|
yield # pragma: no cover
|
||||||
|
|
||||||
|
monkeypatch.setattr(deep_agent, "_run_single_agent_stream", fake_stream)
|
||||||
|
|
||||||
|
scope = ContextualScope(
|
||||||
|
page="project",
|
||||||
|
entity_type="project",
|
||||||
|
entity_id="p1",
|
||||||
|
entity_name="Acme",
|
||||||
|
counts={"tasks": 1, "notes": 0, "milestones": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
context = {
|
||||||
|
"conversation_history": [],
|
||||||
|
"_debug": {"session_id": "s1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
async for item in deep_agent.run_contextual_stream(
|
||||||
|
user_id="user1",
|
||||||
|
message="hi",
|
||||||
|
context=context,
|
||||||
|
scope=scope,
|
||||||
|
):
|
||||||
|
results.append(item)
|
||||||
|
|
||||||
|
assert "Acme" in captured["sys"], "scope block must appear in system prompt"
|
||||||
|
assert "Current view" in captured["sys"], "section header must be present"
|
||||||
|
|
||||||
|
names = captured["tool_names"]
|
||||||
|
assert "get_page_details" in names, "get_page_details tool must be included"
|
||||||
|
|
||||||
|
# Entity-create tools: at least one of these must be present.
|
||||||
|
assert any(n in names for n in ("create_task", "create_note", "update_task")), (
|
||||||
|
"at least one entity-create tool must be present"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "create_timeline" in names, "create_timeline tool must be included"
|
||||||
|
|
||||||
|
# Note edit tools must NOT be exposed.
|
||||||
|
assert "propose_note_edit" not in names, "propose_note_edit must be excluded"
|
||||||
|
|
||||||
|
# Legacy read tools must be excluded — they return shallow snapshots and
|
||||||
|
# cause the agent to under-answer (see trace 0b46841484ba7d024ed9f8d5ac8b1df0).
|
||||||
|
assert "list_projects" not in names, "list_projects must be excluded (legacy read)"
|
||||||
|
assert "get_project" not in names, "get_project must be excluded (legacy read)"
|
||||||
|
assert "list_tasks" not in names, "list_tasks must be excluded (legacy read)"
|
||||||
|
assert "get_task" not in names, "get_task must be excluded (legacy read)"
|
||||||
|
assert "list_notes" not in names, "list_notes must be excluded (legacy read)"
|
||||||
|
assert "get_note" not in names, "get_note must be excluded (legacy read)"
|
||||||
@@ -4,12 +4,8 @@ import pytest
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
WsDomain,
|
|
||||||
WsFrameType,
|
WsFrameType,
|
||||||
WsHomeRequest,
|
WsHomeRequest,
|
||||||
WsFloatingDomain,
|
|
||||||
WsFloatingRequest,
|
|
||||||
WsFloatingScope,
|
|
||||||
WsStreamEnd,
|
WsStreamEnd,
|
||||||
WsStreamStart,
|
WsStreamStart,
|
||||||
WsStreamText,
|
WsStreamText,
|
||||||
@@ -22,11 +18,9 @@ from app.schemas import (
|
|||||||
def test_v3_frame_types_exist():
|
def test_v3_frame_types_exist():
|
||||||
v3_types = [
|
v3_types = [
|
||||||
"home_request",
|
"home_request",
|
||||||
"floating_request",
|
|
||||||
"stream_start",
|
"stream_start",
|
||||||
"stream_text",
|
"stream_text",
|
||||||
"stream_end",
|
"stream_end",
|
||||||
"floating_domain",
|
|
||||||
"data_request",
|
"data_request",
|
||||||
"data_response",
|
"data_response",
|
||||||
"mutation",
|
"mutation",
|
||||||
@@ -86,51 +80,6 @@ def test_home_request_requires_message():
|
|||||||
WsHomeRequest.model_validate({"type": "home_request"})
|
WsHomeRequest.model_validate({"type": "home_request"})
|
||||||
|
|
||||||
|
|
||||||
# ── WsFloatingRequest ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_request_basic():
|
|
||||||
frame = WsFloatingRequest(
|
|
||||||
message="Summarise",
|
|
||||||
scope=WsFloatingScope(type="task", id="task-123"),
|
|
||||||
)
|
|
||||||
assert frame.type == WsFrameType.floating_request
|
|
||||||
assert frame.scope.type == "task"
|
|
||||||
assert frame.scope.id == "task-123"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_request_scope_without_id():
|
|
||||||
frame = WsFloatingRequest(
|
|
||||||
message="Show all",
|
|
||||||
scope=WsFloatingScope(type="project"),
|
|
||||||
)
|
|
||||||
assert frame.scope.id is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_request_serializes():
|
|
||||||
frame = WsFloatingRequest(
|
|
||||||
message="Test",
|
|
||||||
scope=WsFloatingScope(type="note", id="n-1"),
|
|
||||||
)
|
|
||||||
data = frame.model_dump()
|
|
||||||
assert data["type"] == "floating_request"
|
|
||||||
assert data["scope"]["type"] == "note"
|
|
||||||
assert data["scope"]["id"] == "n-1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_request_invalid_scope_type():
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
WsFloatingRequest(
|
|
||||||
message="X",
|
|
||||||
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_request_requires_scope():
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
|
|
||||||
|
|
||||||
|
|
||||||
# ── WsStreamStart ─────────────────────────────────────────────────────
|
# ── WsStreamStart ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -189,51 +138,3 @@ def test_stream_end_deserializes():
|
|||||||
assert frame.request_id == "r3"
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
|
|
||||||
# ── WsFloatingDomain ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_tasks():
|
|
||||||
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
|
||||||
assert frame.type == WsFrameType.floating_domain
|
|
||||||
assert frame.domain.type == "task"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_valid_domains():
|
|
||||||
frame = WsFloatingDomain(
|
|
||||||
request_id="r1",
|
|
||||||
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
|
||||||
)
|
|
||||||
assert frame.domain.type == "project"
|
|
||||||
assert frame.domain.id == "213213-312321-312312-421321"
|
|
||||||
assert frame.domain.section == "task"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_object_valid():
|
|
||||||
frame = WsFloatingDomain(
|
|
||||||
request_id="r1",
|
|
||||||
domain=WsDomain(type="project", id="p1", section="task"),
|
|
||||||
)
|
|
||||||
assert frame.domain.type == "project"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_serializes():
|
|
||||||
d = WsFloatingDomain(
|
|
||||||
request_id="r1",
|
|
||||||
domain=WsDomain(type="timeline"),
|
|
||||||
).model_dump()
|
|
||||||
assert d == {
|
|
||||||
"type": "floating_domain",
|
|
||||||
"request_id": "r1",
|
|
||||||
"domain": {"type": "timeline", "id": None, "section": None},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_deserializes():
|
|
||||||
raw = {
|
|
||||||
"type": "floating_domain",
|
|
||||||
"request_id": "r1",
|
|
||||||
"domain": {"type": "node", "id": "n-1", "section": None},
|
|
||||||
}
|
|
||||||
frame = WsFloatingDomain.model_validate(raw)
|
|
||||||
assert frame.domain.type == "node"
|
|
||||||
assert frame.domain.id == "n-1"
|
|
||||||
|
|||||||
48
tests/test_scout_connector_registry.py
Normal file
48
tests/test_scout_connector_registry.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for the connector registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.scouts.connectors.base import ItemRef
|
||||||
|
from app.scouts.connectors.registry import (
|
||||||
|
get_connector,
|
||||||
|
register_connector,
|
||||||
|
_reset_for_tests,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyConnector:
|
||||||
|
source_type = "dummy"
|
||||||
|
async def list_new(self, scout): return []
|
||||||
|
async def fetch_metadata(self, scout, ref): raise NotImplementedError
|
||||||
|
async def fetch_content(self, scout, ref): raise NotImplementedError
|
||||||
|
async def archive(self, scout, ref): raise NotImplementedError
|
||||||
|
async def setup_watch(self, scout): raise NotImplementedError
|
||||||
|
async def renew_watch(self, scout): raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _clean_registry():
|
||||||
|
_reset_for_tests()
|
||||||
|
yield
|
||||||
|
_reset_for_tests()
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_and_get():
|
||||||
|
c = _DummyConnector()
|
||||||
|
register_connector(c)
|
||||||
|
assert get_connector("dummy") is c
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_source_raises():
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
get_connector("nope")
|
||||||
|
|
||||||
|
|
||||||
|
def test_double_register_replaces():
|
||||||
|
a = _DummyConnector()
|
||||||
|
b = _DummyConnector()
|
||||||
|
register_connector(a)
|
||||||
|
register_connector(b)
|
||||||
|
assert get_connector("dummy") is b
|
||||||
48
tests/test_scout_connectors_base.py
Normal file
48
tests/test_scout_connectors_base.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for the SourceConnector base protocol and shared types."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.scouts.connectors.base import (
|
||||||
|
ItemContent,
|
||||||
|
ItemMetadata,
|
||||||
|
ItemRef,
|
||||||
|
TriageVerdict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_item_ref_round_trips_through_pydantic():
|
||||||
|
ref = ItemRef(source_msg_ref="abc123", received_at=datetime.now(tz=timezone.utc))
|
||||||
|
parsed = ItemRef.model_validate(ref.model_dump())
|
||||||
|
assert parsed.source_msg_ref == "abc123"
|
||||||
|
assert parsed.received_at == ref.received_at
|
||||||
|
|
||||||
|
|
||||||
|
def test_item_metadata_allows_all_optional():
|
||||||
|
meta = ItemMetadata()
|
||||||
|
assert meta.subject is None
|
||||||
|
assert meta.sender is None
|
||||||
|
assert meta.snippet is None
|
||||||
|
assert meta.received_at is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_item_content_requires_metadata_and_body():
|
||||||
|
content = ItemContent(
|
||||||
|
metadata=ItemMetadata(subject="hi"),
|
||||||
|
body_text="hello world",
|
||||||
|
raw_headers={"X-Foo": "bar"},
|
||||||
|
)
|
||||||
|
assert content.metadata.subject == "hi"
|
||||||
|
assert content.body_text == "hello world"
|
||||||
|
assert content.raw_headers["X-Foo"] == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_triage_verdict_constraints():
|
||||||
|
v = TriageVerdict(verdict="relevant", reason="contains task language", confidence=0.92)
|
||||||
|
assert v.verdict == "relevant"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
TriageVerdict(verdict="meh", reason="x", confidence=0.5) # bad enum value
|
||||||
84
tests/test_scout_connectors_gmail.py
Normal file
84
tests/test_scout_connectors_gmail.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Tests for GmailConnector."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models import CloudScoutConfig
|
||||||
|
from app.scouts.connectors.base import ItemRef
|
||||||
|
from app.scouts.connectors.gmail import GmailConnector
|
||||||
|
|
||||||
|
|
||||||
|
def _make_scout():
|
||||||
|
return CloudScoutConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id="00000000-0000-0000-0000-000000000003",
|
||||||
|
provider="gmail",
|
||||||
|
name="Inbox",
|
||||||
|
data_types=[],
|
||||||
|
prompt_template="",
|
||||||
|
oauth_token_encrypted="encrypted-blob",
|
||||||
|
schedule_cron="0 * * * *",
|
||||||
|
enabled=True,
|
||||||
|
auto_trash_spam=False,
|
||||||
|
device_inactivity_pause_days=14,
|
||||||
|
gmail_history_id="100",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_metadata_returns_subject_and_snippet():
|
||||||
|
scout = _make_scout()
|
||||||
|
conn = GmailConnector()
|
||||||
|
fake_message = {
|
||||||
|
"id": "msg-1",
|
||||||
|
"snippet": "preview text",
|
||||||
|
"payload": {"headers": [
|
||||||
|
{"name": "Subject", "value": "Hello"},
|
||||||
|
{"name": "From", "value": "alice@example.com"},
|
||||||
|
{"name": "Date", "value": "Wed, 14 May 2026 10:00:00 +0000"},
|
||||||
|
]},
|
||||||
|
}
|
||||||
|
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
|
||||||
|
mock_svc.return_value.users().messages().get().execute.return_value = fake_message
|
||||||
|
meta = await conn.fetch_metadata(scout, ItemRef(source_msg_ref="msg-1"))
|
||||||
|
assert meta.subject == "Hello"
|
||||||
|
assert meta.sender == "alice@example.com"
|
||||||
|
assert meta.snippet == "preview text"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_content_returns_body_text():
|
||||||
|
import base64
|
||||||
|
scout = _make_scout()
|
||||||
|
conn = GmailConnector()
|
||||||
|
body_data = base64.urlsafe_b64encode(b"hello world").decode()
|
||||||
|
fake_message = {
|
||||||
|
"id": "msg-1",
|
||||||
|
"snippet": "hello world",
|
||||||
|
"payload": {
|
||||||
|
"mimeType": "text/plain",
|
||||||
|
"headers": [
|
||||||
|
{"name": "Subject", "value": "S"},
|
||||||
|
{"name": "From", "value": "a@b"},
|
||||||
|
],
|
||||||
|
"body": {"data": body_data},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
|
||||||
|
mock_svc.return_value.users().messages().get().execute.return_value = fake_message
|
||||||
|
content = await conn.fetch_content(scout, ItemRef(source_msg_ref="msg-1"))
|
||||||
|
assert content.body_text == "hello world"
|
||||||
|
assert content.metadata.subject == "S"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archive_calls_trash():
|
||||||
|
scout = _make_scout()
|
||||||
|
conn = GmailConnector()
|
||||||
|
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
|
||||||
|
await conn.archive(scout, ItemRef(source_msg_ref="msg-1"))
|
||||||
|
mock_svc.return_value.users().messages().trash.assert_called()
|
||||||
270
tests/test_scout_engine.py
Normal file
270
tests/test_scout_engine.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
|
||||||
|
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
|
||||||
|
from app.scouts.connectors.registry import register_connector, _reset_for_tests
|
||||||
|
from app.scouts.engine import ScoutEngine
|
||||||
|
from tests.conftest import _TestSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connector(items, content_for):
|
||||||
|
c = AsyncMock()
|
||||||
|
# source_type must match the scout.provider ("gmail") so get_connector()
|
||||||
|
# finds it when the engine calls get_connector(scout.provider).
|
||||||
|
c.source_type = "gmail"
|
||||||
|
c.list_new = AsyncMock(return_value=items)
|
||||||
|
c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
|
||||||
|
c.archive = AsyncMock()
|
||||||
|
return c
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _registry():
|
||||||
|
_reset_for_tests()
|
||||||
|
yield
|
||||||
|
_reset_for_tests()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_relevant_item_inserted_into_queue(monkeypatch):
|
||||||
|
user_id = "00000000-0000-0000-0000-000000000003" # power tier seeded in conftest
|
||||||
|
scout_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
scout = CloudScoutConfig(
|
||||||
|
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||||
|
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||||
|
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||||
|
)
|
||||||
|
session.add(scout)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
refs = [ItemRef(source_msg_ref="msg-1")]
|
||||||
|
content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
|
||||||
|
connector = _make_connector(refs, content)
|
||||||
|
register_connector(connector)
|
||||||
|
|
||||||
|
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
engine,
|
||||||
|
"_triage_llm",
|
||||||
|
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
|
||||||
|
)
|
||||||
|
|
||||||
|
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0].source_msg_ref == "msg-1"
|
||||||
|
assert rows[0].triage_verdict == "relevant"
|
||||||
|
assert rows[0].status == "queued"
|
||||||
|
connector.archive.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
|
||||||
|
user_id = "00000000-0000-0000-0000-000000000003"
|
||||||
|
scout_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
scout = CloudScoutConfig(
|
||||||
|
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||||
|
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||||
|
enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
|
||||||
|
)
|
||||||
|
session.add(scout)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
refs = [ItemRef(source_msg_ref="msg-spam")]
|
||||||
|
content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
|
||||||
|
connector = _make_connector(refs, content)
|
||||||
|
register_connector(connector)
|
||||||
|
|
||||||
|
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
engine,
|
||||||
|
"_triage_llm",
|
||||||
|
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
|
||||||
|
)
|
||||||
|
|
||||||
|
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||||
|
assert rows == []
|
||||||
|
connector.archive.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
|
||||||
|
user_id = "00000000-0000-0000-0000-000000000003"
|
||||||
|
scout_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
scout = CloudScoutConfig(
|
||||||
|
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||||
|
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||||
|
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||||
|
)
|
||||||
|
session.add(scout)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
refs = [ItemRef(source_msg_ref="msg-2")]
|
||||||
|
content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
|
||||||
|
connector = _make_connector(refs, content)
|
||||||
|
register_connector(connector)
|
||||||
|
|
||||||
|
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
engine,
|
||||||
|
"_triage_llm",
|
||||||
|
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
|
||||||
|
)
|
||||||
|
|
||||||
|
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||||
|
assert rows == []
|
||||||
|
connector.archive.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_idempotent_replay(monkeypatch):
|
||||||
|
user_id = "00000000-0000-0000-0000-000000000003"
|
||||||
|
scout_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
session.add(CloudScoutConfig(
|
||||||
|
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||||
|
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||||
|
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
refs = [ItemRef(source_msg_ref="msg-3")]
|
||||||
|
content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
|
||||||
|
connector = _make_connector(refs, content)
|
||||||
|
register_connector(connector)
|
||||||
|
|
||||||
|
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
engine,
|
||||||
|
"_triage_llm",
|
||||||
|
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||||
|
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||||
|
assert len(rows) == 1, "Replay must not create duplicate queue rows"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_triage_llm_parses_json_response(monkeypatch):
|
||||||
|
"""Real _triage_llm path: mock the LLM ainvoke, verify TriageVerdict parsed correctly."""
|
||||||
|
from unittest.mock import MagicMock # noqa: PLC0415
|
||||||
|
|
||||||
|
from app.models import CloudScoutConfig # noqa: PLC0415
|
||||||
|
|
||||||
|
scout = CloudScoutConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id="00000000-0000-0000-0000-000000000003",
|
||||||
|
provider="gmail",
|
||||||
|
name="test-scout",
|
||||||
|
data_types=[],
|
||||||
|
prompt_template="watch invoices and project updates",
|
||||||
|
schedule_cron="0 * * * *",
|
||||||
|
enabled=True,
|
||||||
|
auto_trash_spam=False,
|
||||||
|
device_inactivity_pause_days=14,
|
||||||
|
)
|
||||||
|
content = ItemContent(
|
||||||
|
metadata=ItemMetadata(subject="Invoice 42", sender="billing@acme.com"),
|
||||||
|
body_text="Payment of €1 200 is due on 2026-06-01. Please confirm receipt.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build a fake LangChain response whose .content is valid JSON.
|
||||||
|
fake_response = MagicMock()
|
||||||
|
fake_response.content = '{"verdict": "relevant", "reason": "invoice due", "confidence": 0.92}'
|
||||||
|
fake_response.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||||
|
|
||||||
|
# Fake LLM: .bind() returns self (or another mock with ainvoke).
|
||||||
|
fake_llm = MagicMock()
|
||||||
|
fake_llm.bind.return_value = fake_llm
|
||||||
|
fake_llm.ainvoke = AsyncMock(return_value=fake_response)
|
||||||
|
|
||||||
|
# Patch get_llm inside app.scouts.engine so our fake is used.
|
||||||
|
monkeypatch.setattr("app.scouts.engine.get_llm", lambda **kwargs: fake_llm)
|
||||||
|
# Disable Langfuse for this test.
|
||||||
|
monkeypatch.setattr("app.scouts.engine.get_langfuse", lambda: None)
|
||||||
|
# Use fallback prompt (no Langfuse) — patch get_prompt_or_fallback to return fallback.
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.scouts.engine.get_prompt_or_fallback",
|
||||||
|
lambda name, fallback: (fallback, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||||
|
verdict = await engine._triage_llm(scout, content)
|
||||||
|
|
||||||
|
assert verdict.verdict == "relevant"
|
||||||
|
assert verdict.reason == "invoice due"
|
||||||
|
assert abs(verdict.confidence - 0.92) < 1e-6
|
||||||
|
fake_llm.ainvoke.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
|
||||||
|
user_id = "00000000-0000-0000-0000-000000000003"
|
||||||
|
scout_id = str(uuid.uuid4())
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
|
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
session.add(CloudScoutConfig(
|
||||||
|
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||||
|
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||||
|
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||||
|
))
|
||||||
|
for i in range(3):
|
||||||
|
session.add(ScoutTriageQueue(
|
||||||
|
id=str(uuid.uuid4()), user_id=user_id, scout_id=scout_id,
|
||||||
|
source_type="gmail", source_msg_ref=f"msg-{i}",
|
||||||
|
triage_verdict="relevant", status="queued",
|
||||||
|
triaged_at=now, expires_at=now + timedelta(days=30),
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
connector = AsyncMock()
|
||||||
|
connector.source_type = "gmail"
|
||||||
|
connector.fetch_metadata = AsyncMock(side_effect=lambda scout, ref: ItemMetadata(
|
||||||
|
subject=f"sub-{ref.source_msg_ref}", snippet=f"snip-{ref.source_msg_ref}",
|
||||||
|
))
|
||||||
|
register_connector(connector)
|
||||||
|
|
||||||
|
sent = []
|
||||||
|
ws = AsyncMock()
|
||||||
|
ws.send_json = AsyncMock(side_effect=lambda payload: sent.append(payload))
|
||||||
|
|
||||||
|
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||||
|
await engine.deliver_pending(uuid.UUID(user_id), ws)
|
||||||
|
|
||||||
|
assert len(sent) == 3
|
||||||
|
assert all(s["type"] == "scout_proposal" for s in sent)
|
||||||
|
subjects = {s["proposal"]["raw_subject"] for s in sent}
|
||||||
|
assert subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||||
|
assert all(r.status == "delivered" for r in rows)
|
||||||
|
assert all(r.delivered_at is not None for r in rows)
|
||||||
106
tests/test_scout_webhook.py
Normal file
106
tests/test_scout_webhook.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""Tests for the Gmail Pub/Sub webhook route.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- Happy path: valid JWT + known user + enabled scout → 204, engine triggered.
|
||||||
|
- Rejection: invalid JWT → 401.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.main import app
|
||||||
|
from app.models import CloudScoutConfig, User
|
||||||
|
from tests.conftest import _TestSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def _pubsub_payload(email: str, history_id: str) -> dict:
|
||||||
|
"""Build a minimal Pub/Sub push envelope."""
|
||||||
|
inner = json.dumps({"emailAddress": email, "historyId": history_id}).encode()
|
||||||
|
return {
|
||||||
|
"message": {"data": base64.b64encode(inner).decode(), "messageId": "m1"},
|
||||||
|
"subscription": "projects/x/subscriptions/gmail-watch-sub",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_triggers_scout_for_matching_user():
|
||||||
|
"""204 returned and ScoutEngine.trigger_scout awaited for the matching scout."""
|
||||||
|
user_id = "00000000-0000-0000-0000-000000000003" # seeded 'power' user
|
||||||
|
scout_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Mutate the seeded user email so the webhook can resolve it,
|
||||||
|
# and add a cloud scout config for gmail.
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
user = await session.get(User, user_id)
|
||||||
|
user.email = "alice@example.com"
|
||||||
|
session.add(
|
||||||
|
CloudScoutConfig(
|
||||||
|
id=scout_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider="gmail",
|
||||||
|
name="Inbox",
|
||||||
|
data_types=[],
|
||||||
|
prompt_template="",
|
||||||
|
schedule_cron="0 * * * *",
|
||||||
|
enabled=True,
|
||||||
|
auto_trash_spam=False,
|
||||||
|
device_inactivity_pause_days=14,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
payload = _pubsub_payload("alice@example.com", "200")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.api.routes.scout_webhooks._verify_pubsub_jwt",
|
||||||
|
return_value=True,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.api.routes.scout_webhooks.async_session",
|
||||||
|
_TestSessionLocal,
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.scouts.engine.ScoutEngine.trigger_scout",
|
||||||
|
new=AsyncMock(),
|
||||||
|
) as mock_trigger,
|
||||||
|
):
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/scouts/webhooks/gmail",
|
||||||
|
json=payload,
|
||||||
|
headers={"Authorization": "Bearer fake-google-jwt"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 204
|
||||||
|
mock_trigger.assert_awaited_once_with(uuid.UUID(scout_id))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webhook_rejects_unverified_jwt():
|
||||||
|
"""401 returned when JWT verification fails."""
|
||||||
|
payload = _pubsub_payload("alice@example.com", "200")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.api.routes.scout_webhooks._verify_pubsub_jwt",
|
||||||
|
return_value=False,
|
||||||
|
):
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=app), base_url="http://test"
|
||||||
|
) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/v1/scouts/webhooks/gmail",
|
||||||
|
json=payload,
|
||||||
|
headers={"Authorization": "Bearer bogus"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 401
|
||||||
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Tests for WS folder index_session handlers (Task 9).
|
||||||
|
|
||||||
|
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||||
|
no real WS connection or LLM call is made.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.api.routes.device_ws import (
|
||||||
|
_handle_index_session_start,
|
||||||
|
_handle_index_file_batch,
|
||||||
|
_handle_index_session_cancel,
|
||||||
|
_index_sessions,
|
||||||
|
)
|
||||||
|
from app.billing.quota import add_token_usage
|
||||||
|
from app.core.folder_indexer import IndexResult
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["free"]
|
||||||
|
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _FakeWebSocket:
|
||||||
|
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[dict] = []
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(json.loads(text))
|
||||||
|
|
||||||
|
def sent_types(self) -> list[str]:
|
||||||
|
return [f["type"] for f in self.sent]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_session_id() -> str:
|
||||||
|
import uuid
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||||
|
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||||
|
async def _fake(**kwargs) -> IndexResult:
|
||||||
|
return IndexResult(summary=summary, tokens_used=tokens)
|
||||||
|
return _fake
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def _clean_sessions():
|
||||||
|
"""Ensure _index_sessions is empty before and after each test."""
|
||||||
|
_index_sessions.clear()
|
||||||
|
yield
|
||||||
|
_index_sessions.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_index_session_happy_path(db_session):
|
||||||
|
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
# Register session.
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"projectId": "proj-1",
|
||||||
|
"totalFiles": 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Verify session was registered.
|
||||||
|
assert session_id in _index_sessions
|
||||||
|
assert _index_sessions[session_id]["total"] == 2
|
||||||
|
assert _index_sessions[session_id]["processed"] == 0
|
||||||
|
# No response frames expected for session_start.
|
||||||
|
assert ws.sent == []
|
||||||
|
|
||||||
|
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||||
|
with patch(
|
||||||
|
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||||
|
# We patch the module-level function in folder_indexer instead:
|
||||||
|
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||||
|
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||||
|
# Wire db_session into the context manager.
|
||||||
|
mock_cm = AsyncMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_async_session.return_value = mock_cm
|
||||||
|
|
||||||
|
await _handle_index_file_batch(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"files": [
|
||||||
|
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||||
|
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||||
|
assert types.count(WsFrameType.index_file_result) == 2
|
||||||
|
assert types.count(WsFrameType.index_session_progress) == 1
|
||||||
|
assert types.count(WsFrameType.index_session_done) == 1
|
||||||
|
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "completed"
|
||||||
|
|
||||||
|
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||||
|
assert progress_frame["processed"] == 2
|
||||||
|
assert progress_frame["total"] == 2
|
||||||
|
|
||||||
|
# Verify session cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def test_index_session_cancel(db_session):
|
||||||
|
"""start then cancel → index_session_done(cancelled)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"totalFiles": 5,
|
||||||
|
})
|
||||||
|
assert session_id in _index_sessions
|
||||||
|
|
||||||
|
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
assert WsFrameType.index_session_done in types
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "cancelled"
|
||||||
|
|
||||||
|
# Session should be cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def test_index_session_quota_exceeded(db_session):
|
||||||
|
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
db_session.add(MonthlyTokenUsage(
|
||||||
|
user_id=USER_ID,
|
||||||
|
year_month=ym,
|
||||||
|
feature="folder_index",
|
||||||
|
tokens_used=100_000, # free tier cap exactly
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"totalFiles": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||||
|
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||||
|
mock_cm = AsyncMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_async_session.return_value = mock_cm
|
||||||
|
|
||||||
|
await _handle_index_file_batch(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"files": [
|
||||||
|
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
# Should have 1 file result (success) then done(quota_exceeded).
|
||||||
|
assert WsFrameType.index_file_result in types
|
||||||
|
assert WsFrameType.index_session_done in types
|
||||||
|
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "quota_exceeded"
|
||||||
|
|
||||||
|
# Session should be cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Integration tests for the unified WebSocket handler (Step 5).
|
"""Integration tests for the unified WebSocket handler (Step 5).
|
||||||
|
|
||||||
Tests the device WS endpoint with home_request and floating_request frames,
|
Tests the device WS endpoint with home_request frames,
|
||||||
verifying that the correct v3 frame sequence is returned.
|
verifying that the correct v3 frame sequence is returned.
|
||||||
|
|
||||||
LLM calls are mocked to avoid network dependency.
|
LLM calls are mocked to avoid network dependency.
|
||||||
@@ -34,7 +34,7 @@ def _override_db(db_session):
|
|||||||
|
|
||||||
|
|
||||||
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
||||||
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
|
"""Receive frames until stream_end or max_frames."""
|
||||||
frames = []
|
frames = []
|
||||||
for _ in range(max_frames):
|
for _ in range(max_frames):
|
||||||
raw = ws.receive_text()
|
raw = ws.receive_text()
|
||||||
@@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context):
|
|||||||
yield "token", "Hello"
|
yield "token", "Hello"
|
||||||
|
|
||||||
|
|
||||||
async def _mock_floating_stream(user_id, message, context):
|
|
||||||
yield "floating_domain", {"type": "task", "id": None, "section": None}
|
|
||||||
yield "token", "Here is a summary"
|
|
||||||
|
|
||||||
|
|
||||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def test_home_request_produces_stream_frames(client):
|
def test_home_request_produces_stream_frames(client):
|
||||||
@@ -63,7 +58,7 @@ def test_home_request_produces_stream_frames(client):
|
|||||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-1", "scout_ids": []
|
||||||
}))
|
}))
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "home_request",
|
"type": "home_request",
|
||||||
@@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client):
|
|||||||
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
|
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
|
|
||||||
def test_floating_request_produces_domain_frame(client):
|
|
||||||
"""floating_request → floating_domain first, then stream_text*, stream_end."""
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
|
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
|
||||||
ws.send_text(json.dumps({
|
|
||||||
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
|
||||||
}))
|
|
||||||
ws.send_text(json.dumps({
|
|
||||||
"type": "floating_request",
|
|
||||||
"request_id": "p1",
|
|
||||||
"message": "Summarize this task",
|
|
||||||
"scope": {"type": "task", "id": "task-123"},
|
|
||||||
}))
|
|
||||||
frames = _recv_until_end(ws)
|
|
||||||
|
|
||||||
types = [f["type"] for f in frames]
|
|
||||||
assert WsFrameType.floating_domain in types
|
|
||||||
assert WsFrameType.stream_end in types
|
|
||||||
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
|
||||||
|
|
||||||
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
|
||||||
assert domain_frame["domain"]["type"] == "task"
|
|
||||||
assert domain_frame["request_id"] == "p1"
|
|
||||||
|
|
||||||
|
|
||||||
def test_home_request_request_id_propagated(client):
|
def test_home_request_request_id_propagated(client):
|
||||||
"""request_id in home_request is echoed in all response frames."""
|
"""request_id in home_request is echoed in all response frames."""
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
@@ -117,7 +85,7 @@ def test_home_request_request_id_propagated(client):
|
|||||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-3", "scout_ids": []
|
||||||
}))
|
}))
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "home_request",
|
"type": "home_request",
|
||||||
@@ -138,7 +106,7 @@ def test_tool_result_dispatch_silent_on_unknown_id(client):
|
|||||||
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "device_hello", "device_id": "dev-4", "agent_ids": []
|
"type": "device_hello", "device_id": "dev-4", "scout_ids": []
|
||||||
}))
|
}))
|
||||||
ws.send_text(json.dumps({
|
ws.send_text(json.dumps({
|
||||||
"type": "tool_result", "id": "no-such-id", "ok": True
|
"type": "tool_result", "id": "no-such-id", "ok": True
|
||||||
|
|||||||
Reference in New Issue
Block a user