21 Commits

Author SHA1 Message Date
Roberto
cc0e258e8c fix(api): WS index frames accept both camelCase and snake_case keys (Electron toSnakeCase compat) 2026-05-13 08:58:46 +02:00
Roberto
12e203e63d fix(api): multi-project manifest lists projects even with zero indexed files 2026-05-12 18:10:57 +02:00
Roberto
ffcd7390f0 feat(api): pagination + search + PDF/DOCX extract in folder agent tools 2026-05-12 17:31:43 +02:00
Roberto
91e880f9d4 fix(api): home agent falls back to multi-project folder manifest when no project_id 2026-05-12 16:54:47 +02:00
Roberto
7d47ca54be feat(api): emit Langfuse generation traces for folder indexer 2026-05-12 16:40:20 +02:00
Roberto
956fa88853 feat(api): multi-project folder manifest for daily brief
Add build_brief_multi_project_manifest() to deep_agent.py that fetches
all project folder manifests via execute_on_client and keeps the top 5
most-recently-modified files per project. Wire into run_home_brief in
brief_agent.py, injecting the <linked_folders> block into the system
prompt alongside FOLDER_TOOLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:40:47 +02:00
Roberto
fb2f59ccea feat(api): inject folder manifest into home agent when project context active
Add optional project_id param to run_home_stream. When set, fetch the linked
folder manifest via _fetch_project_manifest and prepend the <linked_folder>
block to the system prompt. Also build an explicit tools list that extends
_all_tools_for_user with FOLDER_TOOLS so the home agent can read folder
files. device_ws._handle_home_request extracts project_id / projectId from
the home_request frame and forwards it to the runner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:32:20 +02:00
Roberto
56dbb7f4cd feat(api): inject folder manifest into task brief agent
Add _fetch_project_manifest helper that calls read_project_folder_manifest
via execute_on_client. Wire it into run_task_brief_research_stream (new
optional project_id param) so the <linked_folder> block is prepended to the
system prompt when the task belongs to a linked project. Also bind
FOLDER_TOOLS into the task-brief tool palette so the agent can read folder
files. device_ws extracts project_id / projectId from the task_brief_request
frame and forwards it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:31:21 +02:00
Roberto
506f517851 feat(api): manifest formatter with token-budget truncation 2026-05-12 11:28:13 +02:00
Roberto
520c186991 feat(api): scoped read_project_folder_file tool with traversal guard
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:26:02 +02:00
Roberto
582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch,
index_file_result/progress/done), wire dispatch in device_ws message loop,
and implement _handle_index_session_start/cancel/file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00
Roberto
2aeb453229 feat(api): PDF + DOCX extraction in folder indexer
Add pypdf/python-docx deps, _extract_pdf_text/_extract_docx_text helpers,
and summarize_pdf/summarize_docx wrappers that delegate to summarize_text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:15:17 +02:00
Roberto
b7a4edac90 feat(api): folder_indexer.summarize_image via gpt-4o-mini vision
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:09:37 +02:00
Roberto
822b4cd8b1 feat(api): folder_indexer.summarize_text via gpt-4o-mini 2026-05-12 11:05:43 +02:00
Roberto
ab24fc4c91 feat(api): POST /billing/quota/check endpoint
Pre-flight quota check for folder_index. Returns 402 with reason
when file cap or monthly token budget would be exceeded; 200 {"ok": true}
otherwise. Also adds auth_headers_free fixture to conftest.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:14:56 +02:00
Roberto
a98e99f7a2 feat(api): folder quota helpers with atomic token usage
Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 08:23:22 +02:00
Roberto
a0ff285bcd feat(api): tier features for folder integration
Add folder_max_files and folder_monthly_tokens to all four tier dicts
in FEATURES, and add get_feature_value() helper to TierManager.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:39:36 +02:00
Roberto
177c1a87dd feat(api): MonthlyTokenUsage model + AgentRunLog.tokens_used
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:30:33 +02:00
Roberto
441a4ea05c chore(api): fix stale Revises comment in folder migration 2026-05-12 07:21:13 +02:00
Roberto
a693a64bf5 feat(api): add migration for folder token tracking
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:16:23 +02:00
Roberto
67562b8092 Add task brief research agent: Stage 1 deep-research + canvas draft emission
- run_task_brief_research() runner with brief-specific tool set and max_steps=12
- New agents: client_agent (list_clients, get_client) and relations_agent (query_relations)
- search_associative tool wrapping MemoryMiddleware semantic search
- BRIEF_RESEARCH_TOOLS constant: read-only task/project/note/timeline + memory + client/relations
- canvas block extraction in output_formatter (splits visible text from <canvas> draft)
- device_ws.py: task_brief_research request type; emits canvas_draft mutation on stream_end
- Stage 2 briefMode: briefing_context injected into floating system prompt when present
- briefingContext kwarg wired through compile_prompt call chain

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-04 15:09:58 +02:00
24 changed files with 1960 additions and 9 deletions

View File

@@ -56,6 +56,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
# A small model (e.g. gpt-4o-mini) is sufficient.
# LLM_MODEL_BRIEF_AGENT=
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
# LLM_MODEL_TASK_BRIEF_AGENT=
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
LLM_MODEL_SETUP_AGENT=

View File

@@ -0,0 +1,46 @@
"""Add token tracking columns for folder integration.
Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"agent_run_logs",
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
)
op.create_table(
"monthly_token_usage",
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("year_month", sa.String(7), nullable=False),
sa.Column("feature", sa.String(64), nullable=False),
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
)
op.create_index(
"ix_monthly_token_usage_user_month",
"monthly_token_usage",
["user_id", "year_month"],
)
def downgrade() -> None:
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
op.drop_table("monthly_token_usage")
op.drop_column("agent_run_logs", "tokens_used")

View File

@@ -0,0 +1,52 @@
"""Client agent — read-only tools for the clients table."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_clients(search: str = "", limit: int = 20) -> str:
"""List clients, optionally filtered by a name/email substring search.
search: optional substring to match against client name or email.
limit: max rows to return (default 20).
"""
filters: dict[str, Any] = {"limit": limit}
if search:
filters["search"] = search
result = await execute_on_client(action="select", table="clients", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No clients found."
lines = [
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
f"company: {r.get('company', '')})"
for r in rows
]
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
@tool
async def get_client(id: str) -> str:
"""Get full details for one client by UUID.
id: the client's UUID.
"""
if not id:
return "Client id is required."
result = await execute_on_client(action="get", table="clients", data={"id": id})
row = result.get("row") or result.get("rows", [None])[0] if result else None
if not row:
return f"Client '{id}' not found."
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
CLIENT_TOOLS: list[Any] = [list_clients, get_client]

168
app/agents/folder_agent.py Normal file
View File

@@ -0,0 +1,168 @@
"""Scoped file-read and search tools for the project folder feature."""
from __future__ import annotations
from langchain_core.tools import tool
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
from app.core.ws_context import execute_on_client
# Cap returned slice size to keep tool output under control.
_MAX_RETURN_CHARS = 50_000
_MAX_SEARCH_MATCHES = 20
def _is_unsafe_path(rel: str) -> bool:
if not rel:
return True
norm = rel.replace("\\", "/")
if norm.startswith("/"):
return True
# Windows drive letter
if len(rel) >= 2 and rel[1] == ":":
return True
parts = norm.split("/")
return ".." in parts
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
"""Return the raw Electron tool_result dict for a file read."""
return await execute_on_client(
action="read_project_folder_file",
data={
"projectId": project_id,
"relativePath": relative_path,
"offset": offset,
"length": length,
},
)
def _decode(result: dict) -> tuple[str, str, int]:
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
extracts text from base64. For images, returns a placeholder string.
For text, content is already a sliced utf-8 string.
"""
kind = result.get("kind", "text")
content = result.get("content", "") or ""
total = int(result.get("totalSize", 0) or 0)
if kind == "image":
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
if kind == "pdf":
return (_extract_pdf_text(content), kind, total)
if kind == "docx":
return (_extract_docx_text(content), kind, total)
return (content, kind, total)
@tool
async def read_project_folder_file(
project_id: str,
relative_path: str,
offset: int = 0,
length: int = _MAX_RETURN_CHARS,
) -> str:
"""Read a slice of a file inside the project's linked folder.
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
offset: char offset to start reading from (0 = beginning).
length: max chars to return. Default 50000. Use smaller values to save tokens.
Returns text content slice with a header showing position. Header tells you
when more content is available; call again with the suggested next offset.
For PDF / DOCX files the backend extracts text first, then applies offset/length
on the extracted text. For images returns a placeholder; navigate with the
manifest summary instead.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
result = await _fetch_file(project_id, relative_path, offset, length)
text, kind, total_size = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind in ("pdf", "docx"):
# Backend extracted full text — apply offset/length on chars.
sliced = text[offset:offset + length]
slice_end = min(offset + length, len(text))
header = (
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
f"totalChars={len(text)}]"
)
if slice_end < len(text):
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + sliced
if kind == "text":
slice_end = offset + len(text)
header = (
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
f"totalBytes={total_size}]"
)
if slice_end < total_size:
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + text
# image or unknown
return text
@tool
async def search_project_folder_file(
project_id: str,
relative_path: str,
query: str,
context_lines: int = 3,
) -> str:
"""Search a project folder file for a query string (case-insensitive substring).
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
query: text to search for.
context_lines: number of lines of context around each match (default 3).
Returns matching line ranges with surrounding context and 1-based line numbers.
Capped at 20 matches; if more exist the header shows the total.
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
Images and binary files are not searchable.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
if not query:
return "Empty query."
# For text we still need full file; pass length=very large.
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
text, kind, _ = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind == "image":
return "Cannot search inside images."
lines = text.splitlines()
q = query.lower()
matches = [i for i, line in enumerate(lines) if q in line.lower()]
if not matches:
return f"No matches for '{query}' in {relative_path}."
shown = matches[:_MAX_SEARCH_MATCHES]
snippets: list[str] = []
for i in shown:
start = max(0, i - context_lines)
end = min(len(lines), i + context_lines + 1)
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
snippets.append(block)
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
body = "\n---\n".join(snippets)
return header + "\n" + body
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]

View File

@@ -0,0 +1,63 @@
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
# Each tool closure captures the user_id bound at factory time.
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
"""Return a query_relations tool bound to *user_id*."""
@tool
async def query_relations(
subject_label: str = "",
predicate: str = "",
object_label: str = "",
limit: int = 10,
) -> str:
"""Query the relational memory graph for entity relationships.
Returns rows where subject ↔ predicate ↔ object match the given filters.
All parameters are optional — omit to retrieve all relations up to limit.
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
limit: max rows to return (default 10).
"""
import logging
logger = logging.getLogger(__name__)
logger.info(
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
trace_id or "-", user_id, subject_label, predicate, object_label,
)
async with async_session() as db:
memory = MemoryMiddleware(db)
rows = await memory.query_relations(
user_id=user_id,
subject=subject_label or None,
predicate=predicate or None,
object_=object_label or None,
limit=limit,
)
if not rows:
return "No relational memory entries found for the given filters."
lines = [
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
for r in rows
]
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
return query_relations

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@@ -96,3 +96,37 @@ async def list_invoices(
"""
invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices
# ── Quota check ────────────────────────────────────────────────────────
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
class QuotaCheckRequest(BaseModel):
feature: str
estimated_files: int
@router.post("/quota/check")
async def quota_check(
payload: QuotaCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict:
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
if payload.feature != "folder_index":
raise HTTPException(status_code=400, detail="Unknown feature")
try:
await check_folder_quota(
user_id=current_user.id,
tier=current_user.tier,
estimated_files=payload.estimated_files,
db=db,
)
except QuotaExceeded as exc:
raise HTTPException(
status_code=402,
detail={"reason": exc.reason, "message": str(exc)},
)
return {"ok": True}

View File

@@ -43,7 +43,8 @@ from app.api.routes.agent_setup import handle_journey_message, handle_journey_st
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter
@@ -56,6 +57,10 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
# ── v7 folder index session state ─────────────────────────────────────
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
_index_sessions: dict[str, dict] = {}
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@@ -164,6 +169,11 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.task_brief_request:
asyncio.create_task(
_handle_task_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
@@ -174,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_start:
asyncio.create_task(
_handle_index_session_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_file_batch:
asyncio.create_task(
_handle_index_file_batch(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
@@ -205,11 +228,13 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
user_id,
request_id,
session_id,
project_id,
message[:200],
)
@@ -234,7 +259,7 @@ async def _handle_home_request(
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context)
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
@@ -415,6 +440,98 @@ async def _handle_brief_request(
)
# ── v6 Task Brief Handler ────────────────────────────────────────────
async def _handle_task_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
Streams the briefing markdown back to the client.
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
"""
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, project_id,
)
if not task_id:
await websocket.send_text(
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
)
return
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
f"task brief: {task_id}",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
await websocket.send_text(ws_frame.model_dump_json())
elif ws_frame.type == "stream_start":
await websocket.send_text(ws_frame.model_dump_json())
# stream_end is emitted below with mutations — skip formatter's version
except Exception as exc:
logger.error(
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
user_id, request_id, task_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
finally:
clear_client_executor()
# Extract canvas block then emit stream_end with optional mutations.
full_response = "".join(response_chunks)
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
mutations: list[dict] = []
if canvas_content:
mutations.append({
"type": "canvas_draft",
"content": canvas_content,
"kind": canvas_kind,
})
await websocket.send_text(
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
)
logger.info(
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
@@ -472,6 +589,174 @@ async def _handle_journey_message(
clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Register a new folder index session. No response sent — client is declaring intent."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
project_id: str | None = frame.get("projectId") or frame.get("project_id")
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
if not session_id:
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
return
_index_sessions[session_id] = {
"user_id": user_id,
"project_id": project_id,
"processed": 0,
"total": total,
"cancelled": False,
}
logger.info(
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
user_id, session_id, project_id, total,
)
async def _handle_index_session_cancel(
websocket: WebSocket,
frame: dict,
) -> None:
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
session = _index_sessions.get(session_id)
if session:
session["cancelled"] = True
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "cancelled",
}))
_index_sessions.pop(session_id, None)
logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Process a batch of files for an index session, streaming results back."""
# Lazy imports to avoid heavy load at module startup.
from app.core.folder_indexer import ( # noqa: PLC0415
summarize_image,
summarize_pdf,
summarize_docx,
summarize_text,
)
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.billing.quota import add_token_usage # noqa: PLC0415
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
files: list[dict] = frame.get("files", [])
session = _index_sessions.get(session_id)
if not session or session.get("cancelled"):
return
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
cap: int | None = None if raw_cap == -1 else raw_cap
for file_info in files:
if session.get("cancelled"):
return
# Electron's toSnakeCase converts payload keys, so accept both forms.
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
kind: str = file_info.get("kind") or "text"
content: str = file_info.get("content") or ""
ext: str = file_info.get("ext") or ""
mime: str = file_info.get("mime") or "application/octet-stream"
name: str = rel_path.split("/")[-1] or rel_path
try:
if kind == "image":
res = await summarize_image(image_b64=content, mime=mime)
elif kind == "pdf":
res = await summarize_pdf(pdf_b64=content, name=name)
elif kind == "docx":
res = await summarize_docx(docx_b64=content, name=name)
else:
res = await summarize_text(content=content, ext=ext, name=name)
except Exception as exc:
logger.warning(
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
session_id, rel_path, exc,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": None,
"tokensUsed": 0,
"error": str(exc),
}))
session["processed"] += 1
continue
# Account for token usage and check cap.
usage = await add_token_usage(
user_id=user_id,
feature="folder_index",
tokens=res.tokens_used,
db=db,
cap=cap,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": res.summary,
"tokensUsed": res.tokens_used,
}))
session["processed"] += 1
if usage.exhausted:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "quota_exceeded",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session quota_exceeded user=%s session=%s",
user_id, session_id,
)
return
# After processing the batch, emit progress.
processed = session["processed"]
total = session["total"]
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_progress,
"sessionId": session_id,
"processed": processed,
"total": total,
}))
if processed >= total:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "completed",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session_done completed user=%s session=%s processed=%d",
user_id, session_id, processed,
)
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:

139
app/billing/quota.py Normal file
View File

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
"""Raised when a folder operation cannot proceed under the user's tier."""
def __init__(self, reason: str, message: str) -> None:
super().__init__(message)
self.reason = reason # "max_files" | "monthly_tokens"
@dataclass
class TokenUsageResult:
tokens_used: int
exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
*,
user_id: str,
tier: BillingTier,
estimated_files: int,
db: AsyncSession,
) -> None:
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
would be violated. -1 in either feature means unlimited."""
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
if max_files != -1 and estimated_files > max_files:
raise QuotaExceeded(
"max_files",
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
)
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
if cap == -1:
return
ym = _current_year_month()
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)
).scalar_one_or_none()
used = row.tokens_used if row else 0
if used >= cap:
raise QuotaExceeded(
"monthly_tokens",
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
)
async def add_token_usage(
*,
user_id: str,
feature: str,
tokens: int,
db: AsyncSession,
cap: int | None = None,
) -> TokenUsageResult:
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
back to a read-then-write on other engines (e.g. aiosqlite in tests).
Returns post-update total and whether cap is exhausted.
"""
ym = _current_year_month()
# Detect dialect to choose between native upsert and portable fallback.
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
if dialect_name == "postgresql":
# Native atomic upsert — production path.
stmt = (
pg_insert(MonthlyTokenUsage)
.values(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
.on_conflict_do_update(
index_elements=["user_id", "year_month", "feature"],
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
)
.returning(MonthlyTokenUsage.tokens_used)
)
used: int = (await db.execute(stmt)).scalar_one()
await db.commit()
else:
# Portable fallback — used in tests (SQLite) and any non-PG engine.
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == feature,
)
)
).scalar_one_or_none()
if row is None:
row = MonthlyTokenUsage(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
db.add(row)
else:
row.tokens_used += tokens
await db.commit()
await db.refresh(row)
used = row.tokens_used
exhausted = cap is not None and cap != -1 and used >= cap
return TokenUsageResult(tokens_used=used, exhausted=exhausted)

View File

@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 200,
"folder_monthly_tokens": 100_000,
},
"pro": {
"agents": -1, # unlimited
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 5000,
"folder_monthly_tokens": 2_000_000,
},
"power": {
"agents": -1,
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
},
"team": {
"agents": -1,
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
},
}
@@ -123,6 +131,13 @@ class TierManager:
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
"""Return integer feature value for tier. -1 means unlimited."""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if not isinstance(value, int):
return 0
return value
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int:

View File

@@ -29,6 +29,7 @@ class Settings(BaseSettings):
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)

View File

@@ -21,6 +21,7 @@ from app.core.deep_agent import (
_relational_memory_injection,
_run_single_agent_stream,
_trace_id_from_context,
build_brief_multi_project_manifest,
)
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
@@ -159,6 +160,8 @@ async def run_home_brief(
Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines.
"""
from app.agents.folder_agent import FOLDER_TOOLS
trace_id = _trace_id_from_context(context)
today = date.today().isoformat()
language = _resolve_language(context)
@@ -171,7 +174,10 @@ async def run_home_brief(
if today not in system_prompt:
system_prompt += f"\nToday is {today}."
tools = _build_read_tools(user_id, trace_id)
brief_manifest = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,

View File

@@ -12,8 +12,10 @@ from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from app.agents.client_agent import CLIENT_TOOLS
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.relations_agent import make_query_relations_tool
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.agent_session_buffer import session_buffer
@@ -58,6 +60,93 @@ def _language_instruction(context: dict[str, Any]) -> str:
f"All your output text must be written in {lang}."
)
MANIFEST_TOKEN_BUDGET = 3000 # rough budget for <linked_folder> block
def format_folder_manifest(manifest: dict | None) -> str:
"""Format a folder manifest into the <linked_folder> block.
Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
Returns empty string if manifest is None or has no files.
"""
if not manifest or not manifest.get("files"):
return ""
files = list(manifest["files"])
files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
header = (
f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
)
footer_template = "{} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token
body = ""
included = 0
for f in files:
line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
break
body += line
included += 1
omitted = len(files) - included
if omitted > 0:
return header + body + footer_template.format(omitted)
return header + body + "</linked_folder>"
async def _fetch_project_manifest(project_id: str) -> dict | None:
"""Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
from app.core.ws_context import execute_on_client
try:
result = await execute_on_client(
action="read_project_folder_manifest",
data={"projectId": project_id},
)
if not result or not result.get("folderPath"):
return None
return result
except Exception:
return None
async def build_brief_multi_project_manifest() -> str:
"""Build a compact multi-project manifest for the daily brief agent.
Calls execute_on_client('list_projects_with_folder_manifests') and keeps
the top 5 most-recently-modified files per project.
"""
try:
result = await execute_on_client(
action="list_projects_with_folder_manifests",
data={},
)
except Exception:
return ""
projects = (result or {}).get("projects") or []
if not projects:
return ""
blocks: list[str] = ["<linked_folders>"]
any_entry = False
for p in projects:
all_files = p.get("files", []) or []
files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
if not all_files:
blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)")
else:
for f in files:
blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
if len(all_files) > 5:
blocks.append(f"{len(all_files) - 5} more files (use read_project_folder_file by relPath)")
any_entry = True
if not any_entry:
return ""
blocks.append("</linked_folders>")
return "\n".join(blocks)
def _datetime_context_injection(context: dict[str, Any]) -> str:
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
fp = context.get("format_prefs")
@@ -303,6 +392,80 @@ For specific dates not listed, compute local-midnight in the user timezone and c
{request_context}\
"""
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task.
Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity}
# Research workflow
Follow these steps in order, using tools:
1. Read the task fully (title, description, due date, priority, status, project, comments).
2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client.
3. If the project has a clientId: call `get_client(id)` to retrieve full client details.
4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects.
5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions.
6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`.
7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece.
# Output structure
Write the briefing in the user's language. Use this exact structure:
**What needs to be done**
(12 sentences, concrete and specific — what action the user must take)
**Context you should know**
(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies)
**Suggested first step**
(one specific, immediately actionable instruction)
If this is a writing task, append a canvas block at the very end:
<canvas kind="email|document|message">
...ready-to-use draft here...
</canvas>
Do NOT include the canvas block for non-writing tasks.
Do NOT repeat verbatim task fields the user already sees in the UI.
Be concrete — no vague advice. Every bullet should be a fact that changes what the user does.
# Date context
{date_context}
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Request context
{request_context}\
"""
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\
You are an executive assistant continuing a conversation with your principal.
You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity}
Your briefing:
---
{briefing_context}
---
Continue from here. Do NOT repeat the briefing. Refer to it when relevant.
Help the user execute: edit drafts, refine wording, look up additional details, plan next steps.
Stay terse — your principal is a busy executive.
# Date context
{date_context}
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Request context
{request_context}\
"""
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
@@ -679,6 +842,25 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines)
@tool
async def search_associative(query: str, limit: int = 5) -> str:
"""Semantic search across associative (archival) memory for a given query.
Use this to surface long-term memories related to a topic, client, or task
that may not appear in recent episodes.
query: natural-language search phrase.
limit: max results (default 5).
"""
logger.info("deep_agent: search_associative trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=limit)
if not results:
return "No associative memory results found."
lines = [f"- {item}" for item in results]
return "Associative memory results:\n" + "\n".join(lines)
return [
memory_list_blocks,
memory_get,
@@ -689,16 +871,33 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
archival_memory_insert,
archival_memory_search,
conversation_search,
search_associative,
]
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return memory tools that only read — safe for the read-only brief-agent subset."""
all_mem = _memory_tools(user_id, trace_id)
_read_names = {"memory_list_blocks", "memory_get", "archival_memory_search", "conversation_search"}
_read_names = {
"memory_list_blocks", "memory_get", "archival_memory_search",
"conversation_search", "search_associative",
}
return [t for t in all_mem if t.name in _read_names]
def _brief_research_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return the full tool palette for Stage-1 task brief research (read-only)."""
return [
*TASK_TOOLS,
*PROJECT_TOOLS,
*NOTE_TOOLS,
*TIMELINE_TOOLS,
*CLIENT_TOOLS,
*_read_only_memory_tools(user_id, trace_id),
make_query_relations_tool(user_id, trace_id),
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
@@ -1216,9 +1415,26 @@ async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
if not manifest_block:
# No specific project context — surface all linked folders so the agent
# can answer questions like "tell me about project X" using its files.
manifest_block = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
trace_id = _trace_id_from_context(prepared_context)
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
@@ -1227,6 +1443,7 @@ async def run_home_stream(
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="home-agent",
tools=tools,
conversation_history=context.get("conversation_history"),
):
event_type, data = event
@@ -1249,6 +1466,28 @@ async def run_floating_stream(
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
brief_mode: bool = bool(context.get("brief_mode"))
briefing_context_text: str = str(context.get("briefing_context") or "").strip()
if brief_mode and briefing_context_text:
# Stage 2: inject briefing as ground truth context.
# Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
# before compile_prompt sees the remaining standard variables.
template, langfuse_prompt = get_prompt_or_fallback(
"task_brief_followup_system",
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
)
system_prompt = compile_prompt(
template, langfuse_prompt,
date_context=_datetime_context_injection(prepared_context).strip(),
language_instruction=_language_instruction(prepared_context).strip(),
user_identity=_user_identity_injection(prepared_context).strip(),
relational_memory=_relational_memory_injection(prepared_context).strip(),
proactive_hints=_proactive_hints_injection(prepared_context).strip(),
request_context=_request_context_block(prepared_context),
briefing_context=briefing_context_text,
)
else:
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
@@ -1283,6 +1522,58 @@ async def run_floating_stream(
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def run_task_brief_research_stream(
user_id: str,
task_id: str,
context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stage-1 executive assistant: deep research for one task.
Yields ``("token", chunk)`` events like other stream runners.
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
which the WS handler strips and emits as a ``canvas_draft`` mutation.
"""
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(f"task:{task_id}", context)
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
# Inject task_id so the agent knows what to look up first.
research_message = (
f"Prepare a briefing dossier for task ID: {task_id}\n"
"Follow the research workflow: read the task, then project, then client, "
"then cross-project relations, then relevant memory. "
"End with a concrete suggested first step. "
"If this is a writing task, include a <canvas kind=\"...\"> draft."
)
system_prompt, langfuse_prompt = _build_system_prompt(
"task_brief_research_system",
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT,
prepared_context,
)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=research_message,
context=prepared_context,
max_steps=12,
langfuse_prompt=langfuse_prompt,
agent_name="task-brief-agent",
tools=tools,
conversation_history=None,
):
yield event
async def update_core_memory(user_id: str, key: str, value: str) -> None:
"""Compatibility helper kept for callers that expect explicit memory update API."""
async with async_session() as db:

183
app/core/folder_indexer.py Normal file
View File

@@ -0,0 +1,183 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_langfuse,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
summary: str
tokens_used: int
async def _llm_text(messages: list) -> object:
"""Make the LLM call for text summarisation.
Defined as a standalone async function so tests can patch it cleanly
without needing to mock the LLM object itself.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def _llm_vision(messages: list) -> object:
"""Make the LLM call for vision (image) summarisation.
Accepts the message list and returns the response directly, mirroring
the ``_llm_text`` caller pattern so tests can patch it at the module level.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
"""Return a compact summary of an image file using vision.
Parameters
----------
image_b64:
Base64-encoded image bytes.
mime:
MIME type of the image, e.g. ``"image/png"``.
file_name:
Optional file name, attached to the Langfuse trace as input metadata.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
messages = [
SystemMessage(content=template),
HumanMessage(content=[
{"type": "text", "text": "Summarise this image."},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
]),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-image",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": file_name, "mime": mime},
) as gen:
response = await _llm_vision(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_vision(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
"""Return a compact summary of a text file.
Parameters
----------
content:
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
ext:
File extension including the leading dot, e.g. ``".md"``.
name:
File name, e.g. ``"kickoff.md"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
truncated = content[:_MAX_INPUT_CHARS]
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
messages = [
SystemMessage(content=compiled),
HumanMessage(content="Summarise this file."),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-text",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
) as gen:
response = await _llm_text(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_text(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(pdf_b64))
reader = PdfReader(buf)
parts: list[str] = []
for page in reader.pages:
try:
parts.append(page.extract_text() or "")
except Exception:
continue
return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(docx_b64))
doc = DocxDocument(buf)
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a PDF file.
Parameters
----------
pdf_b64:
Base64-encoded PDF bytes.
name:
File name, e.g. ``"report.pdf"``.
"""
text = _extract_pdf_text(pdf_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".pdf", name=name)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a DOCX file.
Parameters
----------
docx_b64:
Base64-encoded DOCX bytes.
name:
File name, e.g. ``"spec.docx"``.
"""
text = _extract_docx_text(docx_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".docx", name=name)

View File

@@ -107,6 +107,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",

View File

@@ -2,11 +2,35 @@
from __future__ import annotations
import re
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
re.DOTALL | re.IGNORECASE,
)
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
Returns ``(visible_text, canvas_content, canvas_kind)``.
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
"""
match = _CANVAS_BLOCK_RE.search(text)
if not match:
return text, None, None
canvas_kind = match.group(1).strip()
canvas_content = match.group(2).strip()
visible = text[: match.start()] + text[match.end() :]
visible = visible.strip()
return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain

View File

@@ -243,6 +243,7 @@ class AgentRunLog(Base):
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
@@ -263,6 +264,17 @@ class AgentRunLog(Base):
)
class MonthlyTokenUsage(Base):
__tablename__ = "monthly_token_usage"
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
# ── Memory models ─────────────────────────────────────────────────────────────

View File

@@ -87,6 +87,15 @@ class WsFrameType(str, Enum):
journey_reply = "journey_reply"
# ── v5 brief frame types ──────────────────────────────────────────
brief_request = "brief_request"
# ── v6 task brief frame types ─────────────────────────────────────
task_brief_request = "task_brief_request"
# ── v7 folder index frame types ───────────────────────────────────
index_session_start = "index_session_start"
index_file_batch = "index_file_batch"
index_session_cancel = "index_session_cancel"
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
class WsToolCall(BaseModel):
@@ -209,6 +218,7 @@ class WsStreamEnd(BaseModel):
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
error: str | None = None
mutations: list[dict[str, Any]] | None = None
class WsDomain(BaseModel):

View File

@@ -39,3 +39,5 @@ lxml>=5.0.0
PyYAML>=6.0.0
apscheduler>=3.10.0
ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1

View File

@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select
from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── Convenience aliases and per-tier user fixtures ────────────────────
@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
"""Alias for db_session — used by folder quota tests."""
return db_session
@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
"""Return the seeded free-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["free"])
)
return result.scalar_one()
@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
"""Return the seeded power-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["power"])
)
return result.scalar_one()
@pytest.fixture
def auth_headers_free() -> dict[str, str]:
"""Authorization header for the seeded free-tier user."""
return auth_header("free")
# ── CLI options ───────────────────────────────────────────────────────
def pytest_addoption(parser):

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.folder_agent import (
read_project_folder_file,
search_project_folder_file,
)
pytestmark = pytest.mark.asyncio
async def test_happy_path():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
assert "file body" in out
assert "kind=text" in out
async def test_traversal_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
assert out == "Access denied"
async def test_absolute_path_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
assert out == "Access denied"
async def test_missing_file():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
assert "not found" in out.lower()
async def test_pagination_signals_more_available():
# Electron returned the first slice, totalSize larger than slice length.
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "big.txt",
"offset": 0,
"length": 11,
})
assert "first chunk" in out
assert "More content available" in out
assert "offset=11" in out
async def test_pdf_extracted_then_sliced(monkeypatch):
from app.agents import folder_agent
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "doc.pdf",
"offset": 0,
"length": 8,
})
assert "kind=pdf" in out
assert "ABC ABC " in out
assert "More content available" in out
async def test_image_returns_placeholder():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
assert "image" in out.lower()
async def test_search_finds_match_with_context():
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "log.txt",
"query": "needle",
"context_lines": 1,
})
assert "needle" in out
assert "matches=1" in out
# Context lines included
assert "beta" in out
assert "gamma" in out
async def test_search_no_match():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "x.txt",
"query": "zzz",
})
assert "No matches" in out
async def test_search_rejects_traversal():
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "../etc/passwd",
"query": "root",
})
assert out == "Access denied"
async def test_search_image_rejected():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "logo.png",
"query": "anything",
})
assert "Cannot search" in out

View File

@@ -0,0 +1,83 @@
"""Folder indexer LLM helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
pytestmark = pytest.mark.asyncio
async def test_summarize_text_returns_summary_and_tokens():
mock_resp = AsyncMock()
mock_resp.content = "Kickoff notes covering scope and deadlines."
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
assert isinstance(result, IndexResult)
assert result.summary == "Kickoff notes covering scope and deadlines."
assert result.tokens_used == 338
async def test_summarize_text_truncates_summary_at_500_chars():
mock_resp = AsyncMock()
mock_resp.content = "x" * 1000
mock_resp.usage_metadata = {"total_tokens": 100}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="x", ext=".md", name="x.md")
assert len(result.summary) <= 500
async def test_summarize_image_uses_vision_content_blocks():
mock_resp = AsyncMock()
mock_resp.content = "Final logo on white background."
mock_resp.usage_metadata = {"total_tokens": 500}
captured = {}
async def fake_llm_vision(messages):
captured["messages"] = messages
return mock_resp
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
assert "Final logo" in result.summary
assert result.tokens_used == 500
# last message contains an image content block
last = captured["messages"][-1]
assert any(
isinstance(p, dict) and p.get("type") == "image_url"
for p in (last.content if isinstance(last.content, list) else [])
)
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
# pypdf.PdfReader returns text from pages
from app.core import folder_indexer
class FakePage:
def extract_text(self): return "PDF page content with project info."
class FakeReader:
pages = [FakePage(), FakePage()]
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
assert "Project info" in result.summary
assert result.tokens_used == 50
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
from app.core import folder_indexer
class FakePara:
def __init__(self, t): self.text = t
class FakeDoc:
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
assert result.summary == "Heading and body."

View File

@@ -0,0 +1,94 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
)
assert exc.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
# No raise
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
)
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db.add(MonthlyTokenUsage(
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
))
await db.commit()
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
)
assert exc.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
await check_folder_quota(
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
)
async def test_add_token_usage_atomic_increment(db, test_user_free):
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
ym = datetime.now(timezone.utc).strftime("%Y-%m")
row = (await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == test_user_free.id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)).scalar_one()
assert row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
result = await add_token_usage(
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
)
assert result.exhausted is True
assert result.tokens_used == 150_000
def test_quota_check_endpoint_rejects(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 500},
headers=auth_headers_free,
)
assert res.status_code == 402
body = res.json()
assert body["detail"]["reason"] == "max_files"
def test_quota_check_endpoint_passes(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 50},
headers=auth_headers_free,
)
assert res.status_code == 200
assert res.json() == {"ok": True}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
pytestmark = pytest.mark.asyncio
def test_format_folder_manifest_basic():
manifest = {
"folderPath": "D:\\Acme",
"lastScannedAt": "2h ago",
"files": [
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
],
}
out = format_folder_manifest(manifest)
assert "<linked_folder>" in out
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
assert "[text]" in out
assert "[image]" in out
def test_format_folder_manifest_truncates_past_budget():
files = [
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
for i in range(2000)
]
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
assert "more files omitted" in out
# Rough token check
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
def test_format_folder_manifest_null_returns_empty():
assert format_folder_manifest(None) == ""
assert format_folder_manifest({"files": []}) == ""
async def test_brief_multi_project_manifest_top_5_per_project():
fake_response = [
{
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
"lastScannedAt": "now",
"files": [
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
for i in range(10)
],
},
{
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
"lastScannedAt": "now",
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
},
]
with patch(
"app.core.deep_agent.execute_on_client",
new=AsyncMock(return_value={"projects": fake_response}),
):
from app.core.deep_agent import build_brief_multi_project_manifest
out = await build_brief_multi_project_manifest()
# Project 1 has 10 files, only top 5 by mtimeMs should appear
assert out.count("[p1]") <= 5
# Project 2 has 1 file, must appear
assert "[p2]" in out or "Beta" in out

View File

@@ -0,0 +1,196 @@
"""Tests for WS folder index_session handlers (Task 9).
Tests the three handler functions directly with a minimal fake WebSocket so
no real WS connection or LLM call is made.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from app.api.routes.device_ws import (
_handle_index_session_start,
_handle_index_file_batch,
_handle_index_session_cancel,
_index_sessions,
)
from app.billing.quota import add_token_usage
from app.core.folder_indexer import IndexResult
from app.models import MonthlyTokenUsage
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS
pytestmark = pytest.mark.asyncio
USER_ID = TEST_USER_IDS["free"]
POWER_USER_ID = TEST_USER_IDS["power"]
# ── Fake WebSocket ────────────────────────────────────────────────────
class _FakeWebSocket:
"""Minimal WebSocket stand-in that records send_text calls."""
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_text(self, text: str) -> None:
self.sent.append(json.loads(text))
def sent_types(self) -> list[str]:
return [f["type"] for f in self.sent]
# ── Helpers ───────────────────────────────────────────────────────────
def _make_session_id() -> str:
import uuid
return str(uuid.uuid4())
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
"""Return an AsyncMock that resolves to a fixed IndexResult."""
async def _fake(**kwargs) -> IndexResult:
return IndexResult(summary=summary, tokens_used=tokens)
return _fake
# ── Fixtures ──────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
"""Ensure _index_sessions is empty before and after each test."""
_index_sessions.clear()
yield
_index_sessions.clear()
# ── Tests ─────────────────────────────────────────────────────────────
async def test_index_session_happy_path(db_session):
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
# Register session.
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"projectId": "proj-1",
"totalFiles": 2,
})
# Verify session was registered.
assert session_id in _index_sessions
assert _index_sessions[session_id]["total"] == 2
assert _index_sessions[session_id]["processed"] == 0
# No response frames expected for session_start.
assert ws.sent == []
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
with patch(
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
# We patch the module-level function in folder_indexer instead:
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
# Wire db_session into the context manager.
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_async_session.return_value = mock_cm
await _handle_index_file_batch(ws, USER_ID, {
"sessionId": session_id,
"files": [
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
],
})
types = ws.sent_types()
# Expect 2 file results + 1 progress + 1 done(completed).
assert types.count(WsFrameType.index_file_result) == 2
assert types.count(WsFrameType.index_session_progress) == 1
assert types.count(WsFrameType.index_session_done) == 1
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "completed"
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
assert progress_frame["processed"] == 2
assert progress_frame["total"] == 2
# Verify session cleaned up.
assert session_id not in _index_sessions
async def test_index_session_cancel(db_session):
"""start then cancel → index_session_done(cancelled)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"totalFiles": 5,
})
assert session_id in _index_sessions
await _handle_index_session_cancel(ws, {"sessionId": session_id})
types = ws.sent_types()
assert WsFrameType.index_session_done in types
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "cancelled"
# Session should be cleaned up.
assert session_id not in _index_sessions
async def test_index_session_quota_exceeded(db_session):
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
# Pre-fill monthly token usage to the free-tier cap (100_000).
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db_session.add(MonthlyTokenUsage(
user_id=USER_ID,
year_month=ym,
feature="folder_index",
tokens_used=100_000, # free tier cap exactly
))
await db_session.commit()
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"totalFiles": 1,
})
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_async_session.return_value = mock_cm
await _handle_index_file_batch(ws, USER_ID, {
"sessionId": session_id,
"files": [
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
],
})
types = ws.sent_types()
# Should have 1 file result (success) then done(quota_exceeded).
assert WsFrameType.index_file_result in types
assert WsFrameType.index_session_done in types
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "quota_exceeded"
# Session should be cleaned up.
assert session_id not in _index_sessions