20 Commits

Author SHA1 Message Date
Roberto
cc0e258e8c fix(api): WS index frames accept both camelCase and snake_case keys (Electron toSnakeCase compat) 2026-05-13 08:58:46 +02:00
Roberto
12e203e63d fix(api): multi-project manifest lists projects even with zero indexed files 2026-05-12 18:10:57 +02:00
Roberto
ffcd7390f0 feat(api): pagination + search + PDF/DOCX extract in folder agent tools 2026-05-12 17:31:43 +02:00
Roberto
91e880f9d4 fix(api): home agent falls back to multi-project folder manifest when no project_id 2026-05-12 16:54:47 +02:00
Roberto
7d47ca54be feat(api): emit Langfuse generation traces for folder indexer 2026-05-12 16:40:20 +02:00
Roberto
956fa88853 feat(api): multi-project folder manifest for daily brief
Add build_brief_multi_project_manifest() to deep_agent.py that fetches
all project folder manifests via execute_on_client and keeps the top 5
most-recently-modified files per project. Wire into run_home_brief in
brief_agent.py, injecting the <linked_folders> block into the system
prompt alongside FOLDER_TOOLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:40:47 +02:00
Roberto
fb2f59ccea feat(api): inject folder manifest into home agent when project context active
Add optional project_id param to run_home_stream. When set, fetch the linked
folder manifest via _fetch_project_manifest and prepend the <linked_folder>
block to the system prompt. Also build an explicit tools list that extends
_all_tools_for_user with FOLDER_TOOLS so the home agent can read folder
files. device_ws._handle_home_request extracts project_id / projectId from
the home_request frame and forwards it to the runner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:32:20 +02:00
Roberto
56dbb7f4cd feat(api): inject folder manifest into task brief agent
Add _fetch_project_manifest helper that calls read_project_folder_manifest
via execute_on_client. Wire it into run_task_brief_research_stream (new
optional project_id param) so the <linked_folder> block is prepended to the
system prompt when the task belongs to a linked project. Also bind
FOLDER_TOOLS into the task-brief tool palette so the agent can read folder
files. device_ws extracts project_id / projectId from the task_brief_request
frame and forwards it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:31:21 +02:00
Roberto
506f517851 feat(api): manifest formatter with token-budget truncation 2026-05-12 11:28:13 +02:00
Roberto
520c186991 feat(api): scoped read_project_folder_file tool with traversal guard
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:26:02 +02:00
Roberto
582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start, index_session_cancel,
index_file_batch, index_file_result, index_session_progress, index_session_done),
wire dispatch in device_ws message loop, and implement _handle_index_session_start,
_handle_index_session_cancel and _handle_index_file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00
Roberto
2aeb453229 feat(api): PDF + DOCX extraction in folder indexer
Add pypdf/python-docx deps, _extract_pdf_text/_extract_docx_text helpers,
and summarize_pdf/summarize_docx wrappers that delegate to summarize_text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:15:17 +02:00
Roberto
b7a4edac90 feat(api): folder_indexer.summarize_image via gpt-4o-mini vision
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:09:37 +02:00
Roberto
822b4cd8b1 feat(api): folder_indexer.summarize_text via gpt-4o-mini 2026-05-12 11:05:43 +02:00
Roberto
ab24fc4c91 feat(api): POST /billing/quota/check endpoint
Pre-flight quota check for folder_index. Returns 402 with reason
when file cap or monthly token budget would be exceeded; 200 {"ok": true}
otherwise. Also adds auth_headers_free fixture to conftest.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:14:56 +02:00
Roberto
a98e99f7a2 feat(api): folder quota helpers with atomic token usage
Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 08:23:22 +02:00
Roberto
a0ff285bcd feat(api): tier features for folder integration
Add folder_max_files and folder_monthly_tokens to all four tier dicts
in FEATURES, and add get_feature_value() helper to TierManager.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:39:36 +02:00
Roberto
177c1a87dd feat(api): MonthlyTokenUsage model + AgentRunLog.tokens_used
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:30:33 +02:00
Roberto
441a4ea05c chore(api): fix stale Revises comment in folder migration 2026-05-12 07:21:13 +02:00
Roberto
a693a64bf5 feat(api): add migration for folder token tracking
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:16:23 +02:00
18 changed files with 1537 additions and 8 deletions

View File

@@ -0,0 +1,46 @@
"""Add token tracking columns for folder integration.
Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the folder token-tracking schema.

    Adds ``agent_run_logs.tokens_used`` and creates the ``monthly_token_usage``
    table plus its ``(user_id, year_month)`` index.
    """
    # The server default lets the NOT NULL column be added to a populated table.
    run_log_tokens = sa.Column(
        "tokens_used", sa.Integer(), nullable=False, server_default="0"
    )
    op.add_column("agent_run_logs", run_log_tokens)

    # One row per (user, month, feature) — enforced by the composite primary key.
    usage_table_items = (
        sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("year_month", sa.String(7), nullable=False),
        sa.Column("feature", sa.String(64), nullable=False),
        sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
        sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
    )
    op.create_table("monthly_token_usage", *usage_table_items)

    # NOTE(review): the primary key already starts with (user_id, year_month);
    # confirm this extra index is actually needed.
    op.create_index(
        "ix_monthly_token_usage_user_month",
        "monthly_token_usage",
        ["user_id", "year_month"],
    )
def downgrade() -> None:
    """Revert ``upgrade()``: drop the usage index and table, then the added column."""
    # Undo in reverse order of creation: index first, then the table it is built on.
    op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
    op.drop_table("monthly_token_usage")
    # Finally remove the per-run token counter from agent_run_logs.
    op.drop_column("agent_run_logs", "tokens_used")

168
app/agents/folder_agent.py Normal file
View File

@@ -0,0 +1,168 @@
"""Scoped file-read and search tools for the project folder feature."""
from __future__ import annotations
from langchain_core.tools import tool
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
from app.core.ws_context import execute_on_client
# Cap returned slice size to keep tool output under control.
_MAX_RETURN_CHARS = 50_000
_MAX_SEARCH_MATCHES = 20
def _is_unsafe_path(rel: str) -> bool:
if not rel:
return True
norm = rel.replace("\\", "/")
if norm.startswith("/"):
return True
# Windows drive letter
if len(rel) >= 2 and rel[1] == ":":
return True
parts = norm.split("/")
return ".." in parts
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
    """Ask the connected client for a file slice and return its raw tool_result dict."""
    # Payload keys are camelCase, as sent to the client-side action.
    request_payload = {
        "projectId": project_id,
        "relativePath": relative_path,
        "offset": offset,
        "length": length,
    }
    tool_result = await execute_on_client(
        action="read_project_folder_file",
        data=request_payload,
    )
    return tool_result
def _decode(result: dict) -> tuple[str, str, int]:
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
extracts text from base64. For images, returns a placeholder string.
For text, content is already a sliced utf-8 string.
"""
kind = result.get("kind", "text")
content = result.get("content", "") or ""
total = int(result.get("totalSize", 0) or 0)
if kind == "image":
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
if kind == "pdf":
return (_extract_pdf_text(content), kind, total)
if kind == "docx":
return (_extract_docx_text(content), kind, total)
return (content, kind, total)
@tool
async def read_project_folder_file(
    project_id: str,
    relative_path: str,
    offset: int = 0,
    length: int = _MAX_RETURN_CHARS,
) -> str:
    """Read a slice of a file inside the project's linked folder.
    Args:
    project_id: project ID.
    relative_path: path relative to the linked folder root.
    offset: char offset to start reading from (0 = beginning).
    length: max chars to return. Default 50000. Use smaller values to save tokens.
    Returns text content slice with a header showing position. Header tells you
    when more content is available; call again with the suggested next offset.
    For PDF / DOCX files the backend extracts text first, then applies offset/length
    on the extracted text. For images returns a placeholder; navigate with the
    manifest summary instead.
    """
    # NOTE: the docstring above is deliberately left as-is — @tool exposes it
    # to the model as the tool description, so it is runtime text.
    if _is_unsafe_path(relative_path):
        return "Access denied"

    # Arguments originate from a model tool call: normalise them so a negative
    # value cannot turn into a Python negative-index slice, and enforce
    # _MAX_RETURN_CHARS as a real upper bound rather than only a default.
    offset = max(0, offset)
    length = max(0, min(length, _MAX_RETURN_CHARS))

    result = await _fetch_file(project_id, relative_path, offset, length)
    text, kind, total_size = _decode(result)
    if not text and kind in ("missing", "error"):
        return f"File not found or unreadable: {relative_path}"

    if kind in ("pdf", "docx"):
        # Backend extracted full text — apply offset/length on chars.
        sliced = text[offset:offset + length]
        # Derive the end from what was actually returned so the header never
        # reports end < offset when offset lies beyond the extracted text.
        slice_end = offset + len(sliced)
        header = (
            f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
            f"totalChars={len(text)}]"
        )
        if slice_end < len(text):
            header += f"\n[More content available — call again with offset={slice_end}.]"
        return header + "\n" + sliced

    if kind == "text":
        # The client already sliced the content; only the header is added here.
        # NOTE(review): slice_end counts characters while totalSize is reported
        # as bytes — for non-ASCII files the two may diverge; confirm client units.
        slice_end = offset + len(text)
        header = (
            f"[file={relative_path} kind=text offset={offset} end={slice_end} "
            f"totalBytes={total_size}]"
        )
        if slice_end < total_size:
            header += f"\n[More content available — call again with offset={slice_end}.]"
        return header + "\n" + text

    # image or unknown
    return text
@tool
async def search_project_folder_file(
    project_id: str,
    relative_path: str,
    query: str,
    context_lines: int = 3,
) -> str:
    """Search a project folder file for a query string (case-insensitive substring).
    Args:
    project_id: project ID.
    relative_path: path relative to the linked folder root.
    query: text to search for.
    context_lines: number of lines of context around each match (default 3).
    Returns matching line ranges with surrounding context and 1-based line numbers.
    Capped at 20 matches; if more exist the header shows the total.
    Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
    Images and binary files are not searchable.
    """
    # NOTE: the docstring above is deliberately left as-is — @tool exposes it
    # to the model as the tool description, so it is runtime text.
    if _is_unsafe_path(relative_path):
        return "Access denied"
    if not query:
        return "Empty query."

    # A negative value would produce an empty window that omits the matching
    # line itself; treat it as "no surrounding context".
    context_lines = max(0, context_lines)

    # For text we still need full file; pass length=very large.
    result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
    text, kind, _ = _decode(result)
    if not text and kind in ("missing", "error"):
        return f"File not found or unreadable: {relative_path}"
    if kind == "image":
        return "Cannot search inside images."

    lines = text.splitlines()
    q = query.lower()
    # 0-based indices of every line containing the query, case-insensitively.
    matches = [i for i, line in enumerate(lines) if q in line.lower()]
    if not matches:
        return f"No matches for '{query}' in {relative_path}."

    shown = matches[:_MAX_SEARCH_MATCHES]
    snippets: list[str] = []
    for i in shown:
        start = max(0, i - context_lines)
        end = min(len(lines), i + context_lines + 1)
        # Line numbers in the output are 1-based, right-aligned to width 5.
        block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
        snippets.append(block)

    header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
    body = "\n---\n".join(snippets)
    return header + "\n" + body
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
@@ -96,3 +96,37 @@ async def list_invoices(
"""
invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices
# ── Quota check ────────────────────────────────────────────────────────
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
class QuotaCheckRequest(BaseModel):
    # Request body for POST /quota/check.
    # Feature being checked; the endpoint only accepts "folder_index".
    feature: str
    # Number of files the client expects to index; forwarded to check_folder_quota.
    estimated_files: int
@router.post("/quota/check")
async def quota_check(
    payload: QuotaCheckRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict:
    """Pre-flight folder quota check. 402 if tier limits would be exceeded."""
    # Only the folder indexing feature has a quota check; anything else is a
    # client error (400).
    if payload.feature != "folder_index":
        raise HTTPException(status_code=400, detail="Unknown feature")
    try:
        await check_folder_quota(
            user_id=current_user.id,
            tier=current_user.tier,
            estimated_files=payload.estimated_files,
            db=db,
        )
    except QuotaExceeded as exc:
        # Chain the original exception so the quota failure stays visible in
        # tracebacks; the response carries the machine-readable reason
        # alongside the human-readable message.
        raise HTTPException(
            status_code=402,
            detail={"reason": exc.reason, "message": str(exc)},
        ) from exc
    return {"ok": True}

View File

@@ -57,6 +57,10 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
# ── v7 folder index session state ─────────────────────────────────────
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
_index_sessions: dict[str, dict] = {}
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@@ -180,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_start:
asyncio.create_task(
_handle_index_session_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_file_batch:
asyncio.create_task(
_handle_index_file_batch(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
@@ -211,11 +228,13 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
user_id,
request_id,
session_id,
project_id,
message[:200],
)
@@ -240,7 +259,7 @@ async def _handle_home_request(
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context)
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
@@ -437,10 +456,11 @@ async def _handle_task_brief_request(
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s [cache_miss]",
user_id, request_id, task_id,
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, project_id,
)
if not task_id:
@@ -469,7 +489,7 @@ async def _handle_task_brief_request(
response_chunks: list[str] = []
try:
event_stream = run_task_brief_research_stream(user_id, task_id, context)
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr]
@@ -569,6 +589,174 @@ async def _handle_journey_message(
clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Register a new folder index session. No response sent — client is declaring intent.

    Reads the session id, project id and total file count from ``frame``
    (camelCase or snake_case keys) and stores a fresh entry in
    ``_index_sessions``. Frames without a session id, or with a file count
    that is not an integer, are logged and ignored.
    """
    session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
    project_id: str | None = frame.get("projectId") or frame.get("project_id")

    if not session_id:
        logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
        return

    # This handler runs as a fire-and-forget task, so a malformed count must be
    # rejected here instead of raising an exception nobody awaits.
    raw_total = frame.get("totalFiles") or frame.get("total_files") or 0
    try:
        total: int = int(raw_total)
    except (TypeError, ValueError):
        logger.warning(
            "device_ws: index_session_start invalid totalFiles=%r user=%s session=%s",
            raw_total, user_id, session_id,
        )
        return

    _index_sessions[session_id] = {
        "user_id": user_id,
        "project_id": project_id,
        "processed": 0,
        "total": total,
        "cancelled": False,
    }
    logger.info(
        "device_ws: index_session_start user=%s session=%s project=%s total=%d",
        user_id, session_id, project_id, total,
    )
async def _handle_index_session_cancel(
    websocket: WebSocket,
    frame: dict,
) -> None:
    """Flag the session as cancelled, reply with index_session_done(cancelled), and forget it."""
    session_id: str = frame.get("sessionId") or frame.get("session_id") or ""

    # Flag the shared dict first: batch handlers holding a reference to it
    # check this flag before each file.
    active = _index_sessions.get(session_id)
    if active:
        active["cancelled"] = True

    # NOTE(review): the done frame is sent even when the session id is unknown,
    # and no ownership check is possible here (no user_id parameter) — confirm
    # that this is intended.
    done_frame = {
        "type": WsFrameType.index_session_done,
        "sessionId": session_id,
        "status": "cancelled",
    }
    await websocket.send_text(json.dumps(done_frame))

    _index_sessions.pop(session_id, None)
    logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
    websocket: WebSocket,
    user_id: str,
    frame: dict,
) -> None:
    """Process a batch of files for an index session, streaming results back.

    Each file is summarised according to its ``kind`` and answered with an
    ``index_file_result`` frame (carrying ``error`` when summarisation failed).
    Token usage is recorded per successful file; once the monthly cap is
    reached the session ends with ``index_session_done(quota_exceeded)``.
    After the batch an ``index_session_progress`` frame is sent, followed by
    ``index_session_done(completed)`` when ``processed >= total``.

    Batches for unknown, cancelled, or foreign sessions are ignored.
    """
    # Lazy imports to avoid heavy load at module startup.
    from app.core.folder_indexer import (  # noqa: PLC0415
        summarize_image,
        summarize_pdf,
        summarize_docx,
        summarize_text,
    )
    from app.billing.tier_manager import tier_manager  # noqa: PLC0415
    from app.billing.quota import add_token_usage  # noqa: PLC0415

    session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
    # `or []` also covers an explicit null in the frame.
    files: list[dict] = frame.get("files") or []

    session = _index_sessions.get(session_id)
    if not session or session.get("cancelled"):
        return
    # Session ids are chosen by the client; only the user who registered the
    # session may feed batches into it.
    if session.get("user_id") != user_id:
        logger.warning(
            "device_ws: index_file_batch session owner mismatch user=%s session=%s",
            user_id, session_id,
        )
        return

    async with async_session() as db:
        tier = await tier_manager.get_tier(user_id, db)
        raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
        # -1 means unlimited; add_token_usage expects None in that case.
        cap: int | None = None if raw_cap == -1 else raw_cap

        for file_info in files:
            # Re-check before every file: cancel and quota exhaustion flip this
            # flag on the shared session dict while this batch is running.
            if session.get("cancelled"):
                return

            # Electron's toSnakeCase converts payload keys, so accept both forms.
            rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
            kind: str = file_info.get("kind") or "text"
            content: str = file_info.get("content") or ""
            ext: str = file_info.get("ext") or ""
            mime: str = file_info.get("mime") or "application/octet-stream"
            name: str = rel_path.split("/")[-1] or rel_path

            try:
                if kind == "image":
                    res = await summarize_image(image_b64=content, mime=mime)
                elif kind == "pdf":
                    res = await summarize_pdf(pdf_b64=content, name=name)
                elif kind == "docx":
                    res = await summarize_docx(docx_b64=content, name=name)
                else:
                    res = await summarize_text(content=content, ext=ext, name=name)
            except Exception as exc:
                # Best effort per file: report the failure to the client and
                # move on to the next file instead of aborting the batch.
                logger.warning(
                    "device_ws: index_file_batch summarize failed session=%s path=%s: %s",
                    session_id, rel_path, exc,
                )
                await websocket.send_text(json.dumps({
                    "type": WsFrameType.index_file_result,
                    "sessionId": session_id,
                    "relPath": rel_path,
                    "summary": None,
                    "tokensUsed": 0,
                    "error": str(exc),
                }))
                session["processed"] += 1
                continue

            # Account for token usage and check cap.
            usage = await add_token_usage(
                user_id=user_id,
                feature="folder_index",
                tokens=res.tokens_used,
                db=db,
                cap=cap,
            )

            await websocket.send_text(json.dumps({
                "type": WsFrameType.index_file_result,
                "sessionId": session_id,
                "relPath": rel_path,
                "summary": res.summary,
                "tokensUsed": res.tokens_used,
            }))
            session["processed"] += 1

            if usage.exhausted:
                # Batches are dispatched as independent tasks; flag the shared
                # dict so the others stop at their next file instead of
                # spending more tokens after the cap is hit.
                session["cancelled"] = True
                await websocket.send_text(json.dumps({
                    "type": WsFrameType.index_session_done,
                    "sessionId": session_id,
                    "status": "quota_exceeded",
                }))
                _index_sessions.pop(session_id, None)
                logger.info(
                    "device_ws: index_session quota_exceeded user=%s session=%s",
                    user_id, session_id,
                )
                return

    # The session may have been cancelled or finalised while the last file was
    # being summarised; its terminal frame has already been sent in that case.
    if session.get("cancelled"):
        return

    # After processing the batch, emit progress.
    processed = session["processed"]
    total = session["total"]
    await websocket.send_text(json.dumps({
        "type": WsFrameType.index_session_progress,
        "sessionId": session_id,
        "processed": processed,
        "total": total,
    }))

    if processed >= total:
        # pop() doubles as a "first finisher wins" guard so that concurrent
        # batches cannot both announce completion.
        if _index_sessions.pop(session_id, None) is None:
            return
        await websocket.send_text(json.dumps({
            "type": WsFrameType.index_session_done,
            "sessionId": session_id,
            "status": "completed",
        }))
        logger.info(
            "device_ws: index_session_done completed user=%s session=%s processed=%d",
            user_id, session_id, processed,
        )
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:

139
app/billing/quota.py Normal file
View File

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
    """Signals that the user's tier does not allow a folder operation to proceed."""

    def __init__(self, reason: str, message: str) -> None:
        # Machine-readable cause: "max_files" | "monthly_tokens"
        self.reason = reason
        # The human-readable message becomes the exception text (str(exc)).
        Exception.__init__(self, message)
@dataclass
class TokenUsageResult:
    """Outcome of ``add_token_usage``: the running total and the cap status."""

    # Total recorded for the (user, month, feature) row after the update.
    tokens_used: int
    # True when a cap was supplied (not None / -1) and tokens_used >= cap.
    exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
    *,
    user_id: str,
    tier: BillingTier,
    estimated_files: int,
    db: AsyncSession,
) -> None:
    """Validate a planned folder index run against the tier's limits.

    Raises:
        QuotaExceeded: with reason ``"max_files"`` when ``estimated_files``
            exceeds the tier's ``folder_max_files``, or ``"monthly_tokens"``
            when this month's recorded ``folder_index`` usage has already
            reached ``folder_monthly_tokens``. A limit of -1 means unlimited.
    """
    # File-count limit first: it needs no database access.
    max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
    too_many_files = max_files != -1 and estimated_files > max_files
    if too_many_files:
        raise QuotaExceeded(
            "max_files",
            f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
        )

    cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
    if cap == -1:
        return

    # Look up what has been recorded for this user in the current UTC month.
    usage_query = select(MonthlyTokenUsage).where(
        MonthlyTokenUsage.user_id == user_id,
        MonthlyTokenUsage.year_month == _current_year_month(),
        MonthlyTokenUsage.feature == "folder_index",
    )
    usage_row = (await db.execute(usage_query)).scalar_one_or_none()
    used = usage_row.tokens_used if usage_row else 0

    if used >= cap:
        raise QuotaExceeded(
            "monthly_tokens",
            f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
        )
async def add_token_usage(
    *,
    user_id: str,
    feature: str,
    tokens: int,
    db: AsyncSession,
    cap: int | None = None,
) -> TokenUsageResult:
    """Add ``tokens`` to the MonthlyTokenUsage row for (user, current month, feature).

    On PostgreSQL this is a single ``INSERT … ON CONFLICT DO UPDATE``; on any
    other engine (e.g. aiosqlite in tests) it is a read-then-write. Both paths
    commit before returning.

    Returns:
        The post-update total and whether it has reached ``cap``. A ``cap`` of
        ``None`` or -1 never reports exhaustion.
    """
    period = _current_year_month()

    # Pick the strategy from the session's dialect; a session without a bind
    # takes the portable path.
    bind = db.bind
    on_postgres = bind is not None and bind.dialect.name == "postgresql"  # type: ignore[union-attr]

    if on_postgres:
        # Native atomic upsert — production path.
        upsert = pg_insert(MonthlyTokenUsage).values(
            user_id=user_id,
            year_month=period,
            feature=feature,
            tokens_used=tokens,
        )
        upsert = upsert.on_conflict_do_update(
            index_elements=["user_id", "year_month", "feature"],
            set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
        ).returning(MonthlyTokenUsage.tokens_used)
        total: int = (await db.execute(upsert)).scalar_one()
        await db.commit()
    else:
        # Portable fallback — used in tests (SQLite) and any non-PG engine.
        lookup = select(MonthlyTokenUsage).where(
            MonthlyTokenUsage.user_id == user_id,
            MonthlyTokenUsage.year_month == period,
            MonthlyTokenUsage.feature == feature,
        )
        record = (await db.execute(lookup)).scalar_one_or_none()
        if record is not None:
            record.tokens_used += tokens
        else:
            record = MonthlyTokenUsage(
                user_id=user_id,
                year_month=period,
                feature=feature,
                tokens_used=tokens,
            )
            db.add(record)
        await db.commit()
        # Reload so the total reflects what was actually stored.
        await db.refresh(record)
        total = record.tokens_used

    if cap is None or cap == -1:
        over_cap = False
    else:
        over_cap = total >= cap
    return TokenUsageResult(tokens_used=total, exhausted=over_cap)

View File

@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 200,
"folder_monthly_tokens": 100_000,
},
"pro": {
"agents": -1, # unlimited
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 5000,
"folder_monthly_tokens": 2_000_000,
},
"power": {
"agents": -1,
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
},
"team": {
"agents": -1,
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
},
}
@@ -123,6 +131,13 @@ class TierManager:
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
    """Return integer feature value for tier. -1 means unlimited.

    Unknown tiers fall back to the ``"free"`` entry of ``FEATURES``. Features
    that are missing or not integer-valued (including boolean flags) yield 0.
    """
    value = FEATURES.get(tier, FEATURES["free"]).get(feature)
    # bool is a subclass of int, so a plain isinstance(value, int) check would
    # let on/off flags through as the limits True/False (i.e. 1/0).
    if isinstance(value, bool) or not isinstance(value, int):
        return 0
    return value
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int:

View File

@@ -21,6 +21,7 @@ from app.core.deep_agent import (
_relational_memory_injection,
_run_single_agent_stream,
_trace_id_from_context,
build_brief_multi_project_manifest,
)
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
@@ -159,6 +160,8 @@ async def run_home_brief(
Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines.
"""
from app.agents.folder_agent import FOLDER_TOOLS
trace_id = _trace_id_from_context(context)
today = date.today().isoformat()
language = _resolve_language(context)
@@ -171,7 +174,10 @@ async def run_home_brief(
if today not in system_prompt:
system_prompt += f"\nToday is {today}."
tools = _build_read_tools(user_id, trace_id)
brief_manifest = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,

View File

@@ -60,6 +60,93 @@ def _language_instruction(context: dict[str, Any]) -> str:
f"All your output text must be written in {lang}."
)
MANIFEST_TOKEN_BUDGET = 3000  # rough budget for <linked_folder> block


def format_folder_manifest(manifest: dict | None) -> str:
    """Format a folder manifest into the <linked_folder> block.

    Files are listed newest first (by ``mtimeMs``). Once the estimated size
    would exceed MANIFEST_TOKEN_BUDGET the remaining, older files are dropped
    and replaced by an "N more files omitted" footer.

    Args:
        manifest: dict with ``files`` (list of dicts carrying ``relPath``,
            ``kind``, ``summary``, ``mtimeMs``) and optionally ``folderPath``
            and ``lastScannedAt``.

    Returns:
        The formatted block, or an empty string if manifest is None or has no
        files.
    """
    if not manifest or not manifest.get("files"):
        return ""

    # `or 0` also covers an explicit null mtime, which would otherwise make the
    # sort compare None with int.
    files = sorted(manifest["files"], key=lambda f: f.get("mtimeMs") or 0, reverse=True)

    header = (
        f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
        f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
    )
    footer_template = "{} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
    char_budget = MANIFEST_TOKEN_BUDGET * 4  # ~4 chars/token

    # Track the running size instead of re-measuring a growing string.
    used_chars = len(header) + len(footer_template.format(0))
    lines: list[str] = []
    for f in files:
        # A malformed entry without relPath is rendered as "?" rather than
        # aborting the whole prompt build with a KeyError.
        line = f"- /{f.get('relPath', '?')} [{f.get('kind','text')}] {f.get('summary','')}\n"
        if used_chars + len(line) > char_budget:
            break
        lines.append(line)
        used_chars += len(line)

    body = "".join(lines)
    omitted = len(files) - len(lines)
    if omitted > 0:
        return header + body + footer_template.format(omitted)
    return header + body + "</linked_folder>"
async def _fetch_project_manifest(project_id: str) -> dict | None:
    """Ask the client for a project's folder manifest.

    Returns the manifest dict, or None when the call fails or the result has
    no ``folderPath``.
    """
    from app.core.ws_context import execute_on_client

    try:
        manifest = await execute_on_client(
            action="read_project_folder_manifest",
            data={"projectId": project_id},
        )
        # An empty result or one without folderPath yields no manifest.
        has_folder = bool(manifest) and bool(manifest.get("folderPath"))
    except Exception:
        # Best effort: any failure while fetching just means "no manifest".
        return None
    return manifest if has_folder else None
async def build_brief_multi_project_manifest() -> str:
    """Build a compact multi-project manifest for the daily brief agent.

    Calls execute_on_client('list_projects_with_folder_manifests') and keeps
    the top 5 most-recently-modified files per project.

    Returns:
        A ``<linked_folders>`` block listing every returned project (including
        ones without indexed files), or an empty string when the client call
        fails or returns no projects.
    """
    # NOTE(review): execute_on_client is not imported inside this function —
    # confirm it is in module scope; a NameError here would be swallowed by the
    # except below and silently yield "".
    try:
        result = await execute_on_client(
            action="list_projects_with_folder_manifests",
            data={},
        )
    except Exception:
        # Best effort: the brief simply runs without folder context.
        return ""

    projects = (result or {}).get("projects") or []
    if not projects:
        return ""

    blocks: list[str] = ["<linked_folders>"]
    for p in projects:
        all_files = p.get("files", []) or []
        # `or 0` also covers an explicit null mtime, which would otherwise make
        # the sort compare None with int.
        files = sorted(all_files, key=lambda f: f.get("mtimeMs") or 0, reverse=True)[:5]
        blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
        blocks.append(f"  path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
        if not all_files:
            blocks.append("  (no indexed files yet — folder is linked but empty or unscanned)")
            continue
        for f in files:
            # A malformed entry without relPath is rendered as "?" rather than
            # raising KeyError and aborting the whole prompt build.
            blocks.append(f"  - /{f.get('relPath', '?')} [{f.get('kind','text')}] {f.get('summary','')}")
        if len(all_files) > 5:
            blocks.append(f"{len(all_files) - 5} more files (use read_project_folder_file by relPath)")

    # Every project contributes at least its two header lines, so reaching this
    # point with a non-empty project list always produces content.
    blocks.append("</linked_folders>")
    return "\n".join(blocks)
def _datetime_context_injection(context: dict[str, Any]) -> str:
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
fp = context.get("format_prefs")
@@ -1328,9 +1415,26 @@ async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
if not manifest_block:
# No specific project context — surface all linked folders so the agent
# can answer questions like "tell me about project X" using its files.
manifest_block = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
trace_id = _trace_id_from_context(prepared_context)
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
@@ -1339,6 +1443,7 @@ async def run_home_stream(
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="home-agent",
tools=tools,
conversation_history=context.get("conversation_history"),
):
event_type, data = event
@@ -1421,6 +1526,7 @@ async def run_task_brief_research_stream(
user_id: str,
task_id: str,
context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stage-1 executive assistant: deep research for one task.
@@ -1428,8 +1534,10 @@ async def run_task_brief_research_stream(
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
which the WS handler strips and emits as a ``canvas_draft`` mutation.
"""
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(f"task:{task_id}", context)
tools = _brief_research_tools(user_id, _trace_id_from_context(prepared_context))
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
# Inject task_id so the agent knows what to look up first.
research_message = (
@@ -1446,6 +1554,12 @@ async def run_task_brief_research_stream(
prepared_context,
)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,

183
app/core/folder_indexer.py Normal file
View File

@@ -0,0 +1,183 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_langfuse,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
    """Outcome of summarising one file."""

    # Summary text: the model reply trimmed to at most 500 characters, or the
    # fixed "Could not extract text" notice used for empty PDF/DOCX extraction.
    summary: str
    # Value of the "total" usage entry for the LLM call; 0 when no call was made.
    tokens_used: int
async def _llm_text(messages: list) -> object:
    """Run the chat completion used for text summarisation.

    Kept as a module-level coroutine so tests can patch this single name
    rather than mocking the LLM object.
    """
    model = get_llm(model="gpt-4o-mini", temperature=0.2)
    reply = await model.ainvoke(messages)
    return reply
async def _llm_vision(messages: list) -> object:
    """Run the chat completion used for image (vision) summarisation.

    Takes the prepared message list and hands back the model response
    unchanged. It mirrors ``_llm_text`` so tests can patch it at module level
    in the same way.
    """
    model = get_llm(model="gpt-4o-mini", temperature=0.2)
    reply = await model.ainvoke(messages)
    return reply
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
    """Summarise an image file with the vision model.

    Parameters
    ----------
    image_b64:
        Base64-encoded image bytes; embedded in a ``data:`` URL for the model.
    mime:
        MIME type of the image, e.g. ``"image/png"``.
    file_name:
        Optional file name. It is only recorded as input on the Langfuse
        observation and is not sent to the model.

    Returns
    -------
    IndexResult
        The stripped model reply capped at 500 characters, plus the ``"total"``
        entry of the extracted usage (0 when absent).
    """
    system_text, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
    data_url = f"data:{mime};base64,{image_b64}"
    user_parts = [
        {"type": "text", "text": "Summarise this image."},
        {"type": "image_url", "image_url": {"url": data_url}},
    ]
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content=user_parts),
    ]

    tracer = get_langfuse()
    if tracer is None:
        # Tracing unavailable: call the model directly.
        response = await _llm_vision(messages)
        usage = extract_usage(response)
    else:
        with tracer.start_as_current_observation(
            as_type="generation",
            name="folder-summarize-image",
            model="gpt-4o-mini",
            prompt=prompt_obj,
            input={"file_name": file_name, "mime": mime},
        ) as generation:
            response = await _llm_vision(messages)
            usage = extract_usage(response)
            generation.update(output=response.content, usage_details=usage)

    text = (response.content or "").strip()
    return IndexResult(summary=text[:500], tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
    """Summarise a text file in one compact sentence.

    Parameters
    ----------
    content:
        Raw text of the file; only the first ``_MAX_INPUT_CHARS`` characters
        are placed in the prompt.
    ext:
        File extension including the leading dot, e.g. ``".md"``.
    name:
        File name, e.g. ``"kickoff.md"``.

    Returns
    -------
    IndexResult
        The stripped model reply capped at 500 characters, plus the ``"total"``
        entry of the extracted usage (0 when absent).
    """
    template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
    clipped = content[:_MAX_INPUT_CHARS]
    system_text = compile_prompt(template, prompt_obj, ext=ext, name=name, content=clipped)
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content="Summarise this file."),
    ]

    tracer = get_langfuse()
    if tracer is None:
        # Tracing unavailable: call the model directly.
        response = await _llm_text(messages)
        usage = extract_usage(response)
    else:
        with tracer.start_as_current_observation(
            as_type="generation",
            name="folder-summarize-text",
            model="gpt-4o-mini",
            prompt=prompt_obj,
            # Only the content length is recorded on the trace, not the text.
            input={"file_name": name, "ext": ext, "content_chars": len(clipped)},
        ) as generation:
            response = await _llm_text(messages)
            usage = extract_usage(response)
            generation.update(output=response.content, usage_details=usage)

    text = (response.content or "").strip()
    return IndexResult(summary=text[:500], tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
    """Decode a base64 PDF and return the text of its pages joined by newlines.

    Extraction is best-effort per page: a page whose ``extract_text()`` raises
    is skipped so one bad page does not discard the whole document. Skipped
    pages are logged with their traceback instead of being dropped silently.

    Parameters
    ----------
    pdf_b64:
        Base64-encoded PDF bytes.

    Returns
    -------
    str
        Newline-joined page text with surrounding whitespace stripped; empty
        when no page yielded text.
    """
    import logging

    reader = PdfReader(io.BytesIO(base64.b64decode(pdf_b64)))
    parts: list[str] = []
    for page_number, page in enumerate(reader.pages, start=1):
        try:
            parts.append(page.extract_text() or "")
        except Exception:
            # Deliberately broad: any per-page failure is non-fatal here.
            logging.getLogger(__name__).warning(
                "Skipping PDF page %d: text extraction failed", page_number, exc_info=True
            )
    return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
    """Decode a base64 DOCX and return its non-empty paragraphs joined by newlines."""
    raw = base64.b64decode(docx_b64)
    document = DocxDocument(io.BytesIO(raw))
    # Paragraphs with empty text are dropped so they do not add blank lines.
    lines = [para.text for para in document.paragraphs if para.text]
    return "\n".join(lines).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
    """Summarise a PDF by extracting its text and delegating to ``summarize_text``.

    Parameters
    ----------
    pdf_b64:
        Base64-encoded PDF bytes.
    name:
        File name, e.g. ``"report.pdf"``.

    Returns
    -------
    IndexResult
        The text summary, or a fixed "Could not extract text" result with zero
        tokens when extraction yields nothing (no LLM call is made then).
    """
    extracted = _extract_pdf_text(pdf_b64)
    if extracted:
        return await summarize_text(content=extracted, ext=".pdf", name=name)
    return IndexResult(summary="Could not extract text", tokens_used=0)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
    """Summarise a DOCX by extracting its text and delegating to ``summarize_text``.

    Parameters
    ----------
    docx_b64:
        Base64-encoded DOCX bytes.
    name:
        File name, e.g. ``"spec.docx"``.

    Returns
    -------
    IndexResult
        The text summary, or a fixed "Could not extract text" result with zero
        tokens when extraction yields nothing (no LLM call is made then).
    """
    extracted = _extract_docx_text(docx_b64)
    if extracted:
        return await summarize_text(content=extracted, ext=".docx", name=name)
    return IndexResult(summary="Could not extract text", tokens_used=0)

View File

@@ -243,6 +243,7 @@ class AgentRunLog(Base):
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
@@ -263,6 +264,17 @@ class AgentRunLog(Base):
)
class MonthlyTokenUsage(Base):
    """Token counter row keyed by (user_id, year_month, feature)."""

    __tablename__ = "monthly_token_usage"

    # Rows are removed with their user (ON DELETE CASCADE).
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        primary_key=True,
    )
    # 'YYYY-MM'
    year_month: Mapped[str] = mapped_column(String(7), primary_key=True)
    feature: Mapped[str] = mapped_column(String(64), primary_key=True)
    tokens_used: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        default=0,
        server_default="0",
    )
# ── Memory models ─────────────────────────────────────────────────────────────

View File

@@ -89,6 +89,13 @@ class WsFrameType(str, Enum):
brief_request = "brief_request"
# ── v6 task brief frame types ─────────────────────────────────────
task_brief_request = "task_brief_request"
# ── v7 folder index frame types ───────────────────────────────────
index_session_start = "index_session_start"
index_file_batch = "index_file_batch"
index_session_cancel = "index_session_cancel"
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
class WsToolCall(BaseModel):

View File

@@ -39,3 +39,5 @@ lxml>=5.0.0
PyYAML>=6.0.0
apscheduler>=3.10.0
ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1

View File

@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select
from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── Convenience aliases and per-tier user fixtures ────────────────────
@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
    """Expose ``db_session`` under the shorter name the folder quota tests use."""
    session = db_session
    return session
@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
    """Load the seeded free-tier ``User`` row (exactly one must exist)."""
    stmt = select(User).where(User.id == TEST_USER_IDS["free"])
    rows = await db_session.execute(stmt)
    return rows.scalar_one()
@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
    """Load the seeded power-tier ``User`` row (exactly one must exist)."""
    stmt = select(User).where(User.id == TEST_USER_IDS["power"])
    rows = await db_session.execute(stmt)
    return rows.scalar_one()
@pytest.fixture
def auth_headers_free() -> dict[str, str]:
    """Build the Authorization header for the seeded free-tier user."""
    headers = auth_header("free")
    return headers
# ── CLI options ───────────────────────────────────────────────────────
def pytest_addoption(parser):

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.folder_agent import (
read_project_folder_file,
search_project_folder_file,
)
pytestmark = pytest.mark.asyncio
async def test_happy_path():
    """A readable text file is returned with its body and kind marker."""
    client_reply = {"content": "file body", "kind": "text", "totalSize": 9}
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await read_project_folder_file.ainvoke(
            {"project_id": "p1", "relative_path": "docs/x.md"}
        )
    for expected in ("file body", "kind=text"):
        assert expected in rendered
async def test_traversal_rejected():
    """A relative path that climbs out of the folder is refused."""
    tool_args = {"project_id": "p1", "relative_path": "../../etc/passwd"}
    result = await read_project_folder_file.ainvoke(tool_args)
    assert result == "Access denied"
async def test_absolute_path_rejected():
    """A Windows drive-letter absolute path is refused."""
    tool_args = {"project_id": "p1", "relative_path": "C:\\Windows\\foo"}
    result = await read_project_folder_file.ainvoke(tool_args)
    assert result == "Access denied"
async def test_missing_file():
    """A client reply of kind "missing" is reported as not found."""
    client_reply = {"content": "", "kind": "missing", "totalSize": 0}
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await read_project_folder_file.ainvoke(
            {"project_id": "p1", "relative_path": "ghost.md"}
        )
    assert "not found" in rendered.lower()
async def test_pagination_signals_more_available():
    """A partial slice advertises the remaining content and the next offset."""
    # The client returned only the first slice; totalSize exceeds its length.
    client_reply = {"content": "first chunk", "kind": "text", "totalSize": 1000}
    tool_args = {
        "project_id": "p1",
        "relative_path": "big.txt",
        "offset": 0,
        "length": 11,
    }
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await read_project_folder_file.ainvoke(tool_args)
    for expected in ("first chunk", "More content available", "offset=11"):
        assert expected in rendered
async def test_pdf_extracted_then_sliced(monkeypatch):
    """PDF content is converted to text first, then offset/length is applied."""
    from app.agents import folder_agent

    def fake_extract(b):
        # Stand-in for real PDF parsing: always yields 400 characters of text.
        return "ABC " * 100

    monkeypatch.setattr(folder_agent, "_extract_pdf_text", fake_extract)

    client_reply = {"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}
    tool_args = {
        "project_id": "p1",
        "relative_path": "doc.pdf",
        "offset": 0,
        "length": 8,
    }
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await read_project_folder_file.ainvoke(tool_args)
    for expected in ("kind=pdf", "ABC ABC ", "More content available"):
        assert expected in rendered
async def test_image_returns_placeholder():
    """Reading an image yields output that mentions "image"."""
    client_reply = {"content": "iVBORw0K", "kind": "image", "totalSize": 1024}
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await read_project_folder_file.ainvoke(
            {"project_id": "p1", "relative_path": "logo.png"}
        )
    assert "image" in rendered.lower()
async def test_search_finds_match_with_context():
    """A hit is reported together with its surrounding context lines."""
    file_text = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
    client_reply = {"content": file_text, "kind": "text", "totalSize": len(file_text)}
    tool_args = {
        "project_id": "p1",
        "relative_path": "log.txt",
        "query": "needle",
        "context_lines": 1,
    }
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await search_project_folder_file.ainvoke(tool_args)
    assert "needle" in rendered
    assert "matches=1" in rendered
    # One line of context on each side of the hit.
    assert "beta" in rendered
    assert "gamma" in rendered
async def test_search_no_match():
    """A query absent from the file produces a "No matches" message."""
    client_reply = {"content": "nothing here", "kind": "text", "totalSize": 12}
    tool_args = {
        "project_id": "p1",
        "relative_path": "x.txt",
        "query": "zzz",
    }
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await search_project_folder_file.ainvoke(tool_args)
    assert "No matches" in rendered
async def test_search_rejects_traversal():
    """The search tool refuses a path that climbs out of the folder."""
    tool_args = {
        "project_id": "p1",
        "relative_path": "../etc/passwd",
        "query": "root",
    }
    result = await search_project_folder_file.ainvoke(tool_args)
    assert result == "Access denied"
async def test_search_image_rejected():
    """Searching inside an image file is refused with a "Cannot search" message."""
    client_reply = {"content": "b64data", "kind": "image", "totalSize": 100}
    tool_args = {
        "project_id": "p1",
        "relative_path": "logo.png",
        "query": "anything",
    }
    fake_client = AsyncMock(return_value=client_reply)
    with patch("app.agents.folder_agent.execute_on_client", new=fake_client):
        rendered = await search_project_folder_file.ainvoke(tool_args)
    assert "Cannot search" in rendered

View File

@@ -0,0 +1,83 @@
"""Folder indexer LLM helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
pytestmark = pytest.mark.asyncio
async def test_summarize_text_returns_summary_and_tokens():
    """The LLM reply text and total token count are surfaced on the result."""
    llm_reply = AsyncMock()
    llm_reply.content = "Kickoff notes covering scope and deadlines."
    llm_reply.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}

    fake_llm = AsyncMock(return_value=llm_reply)
    with patch("app.core.folder_indexer._llm_text", new=fake_llm):
        outcome = await summarize_text(content="hello world", ext=".md", name="kickoff.md")

    assert isinstance(outcome, IndexResult)
    assert outcome.summary == "Kickoff notes covering scope and deadlines."
    assert outcome.tokens_used == 338
async def test_summarize_text_truncates_summary_at_500_chars():
    """An over-long LLM reply is cut down to at most 500 characters."""
    llm_reply = AsyncMock()
    llm_reply.content = "x" * 1000
    llm_reply.usage_metadata = {"total_tokens": 100}

    fake_llm = AsyncMock(return_value=llm_reply)
    with patch("app.core.folder_indexer._llm_text", new=fake_llm):
        outcome = await summarize_text(content="x", ext=".md", name="x.md")

    assert len(outcome.summary) <= 500
async def test_summarize_image_uses_vision_content_blocks():
    """The image is sent to the model as an ``image_url`` content block."""
    llm_reply = AsyncMock()
    llm_reply.content = "Final logo on white background."
    llm_reply.usage_metadata = {"total_tokens": 500}

    seen = {}

    async def fake_llm_vision(messages):
        # Record what summarize_image passes to the model.
        seen["messages"] = messages
        return llm_reply

    with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
        outcome = await summarize_image(image_b64="iVBORw0KG", mime="image/png")

    assert "Final logo" in outcome.summary
    assert outcome.tokens_used == 500

    # The final message must carry an image content block.
    final_message = seen["messages"][-1]
    parts = final_message.content if isinstance(final_message.content, list) else []
    assert any(isinstance(part, dict) and part.get("type") == "image_url" for part in parts)
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
    """PDF text is pulled from the reader's pages and then summarised."""
    from app.core import folder_indexer

    class FakePage:
        def extract_text(self):
            return "PDF page content with project info."

    class FakeReader:
        pages = [FakePage(), FakePage()]

    # Replace pypdf's reader so no real PDF parsing happens.
    monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())

    llm_reply = AsyncMock()
    llm_reply.content = "Project info doc."
    llm_reply.usage_metadata = {"total_tokens": 50}

    async def fake_llm(messages):
        return llm_reply

    with patch("app.core.folder_indexer._llm_text", new=fake_llm):
        outcome = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")

    assert "Project info" in outcome.summary
    assert outcome.tokens_used == 50
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
    """DOCX paragraphs are extracted and the resulting text is summarised."""
    from app.core import folder_indexer

    class FakePara:
        def __init__(self, t):
            self.text = t

    class FakeDoc:
        paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]

    # Replace python-docx's Document so no real DOCX parsing happens.
    monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())

    llm_reply = AsyncMock()
    llm_reply.content = "Heading and body."
    llm_reply.usage_metadata = {"total_tokens": 30}

    async def fake_llm(messages):
        return llm_reply

    with patch("app.core.folder_indexer._llm_text", new=fake_llm):
        outcome = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")

    assert outcome.summary == "Heading and body."

View File

@@ -0,0 +1,94 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
    """A free-tier request for 500 files is rejected with reason "max_files"."""
    with pytest.raises(QuotaExceeded) as raised:
        await check_folder_quota(
            user_id=test_user_free.id,
            tier="free",
            estimated_files=500,
            db=db,
        )
    assert raised.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
    """A free-tier request for 50 files passes; the check must not raise."""
    await check_folder_quota(
        user_id=test_user_free.id,
        tier="free",
        estimated_files=50,
        db=db,
    )
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
    """With 100_000 tokens already recorded this month, the check is rejected."""
    current_month = datetime.now(timezone.utc).strftime("%Y-%m")
    usage_row = MonthlyTokenUsage(
        user_id=test_user_free.id,
        year_month=current_month,
        feature="folder_index",
        tokens_used=100_000,
    )
    db.add(usage_row)
    await db.commit()

    with pytest.raises(QuotaExceeded) as raised:
        await check_folder_quota(
            user_id=test_user_free.id,
            tier="free",
            estimated_files=10,
            db=db,
        )
    assert raised.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
    """A power-tier request for a very large file count must not raise."""
    await check_folder_quota(
        user_id=test_user_power.id,
        tier="power",
        estimated_files=999_999,
        db=db,
    )
async def test_add_token_usage_atomic_increment(db, test_user_free):
    """Two successive additions accumulate into one monthly row."""
    for amount in (1500, 2500):
        await add_token_usage(
            user_id=test_user_free.id, feature="folder_index", tokens=amount, db=db
        )

    current_month = datetime.now(timezone.utc).strftime("%Y-%m")
    stmt = select(MonthlyTokenUsage).where(
        MonthlyTokenUsage.user_id == test_user_free.id,
        MonthlyTokenUsage.year_month == current_month,
        MonthlyTokenUsage.feature == "folder_index",
    )
    fetched = await db.execute(stmt)
    usage_row = fetched.scalar_one()
    assert usage_row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
    """Adding more tokens than the cap reports exhaustion and the new total."""
    outcome = await add_token_usage(
        user_id=test_user_free.id,
        feature="folder_index",
        tokens=150_000,
        db=db,
        cap=100_000,
    )
    assert outcome.exhausted is True
    assert outcome.tokens_used == 150_000
def test_quota_check_endpoint_rejects(client, auth_headers_free):
    """The quota endpoint answers 402 with reason "max_files" for 500 files."""
    payload = {"feature": "folder_index", "estimated_files": 500}
    response = client.post(
        "/api/v1/billing/quota/check",
        json=payload,
        headers=auth_headers_free,
    )
    assert response.status_code == 402
    detail = response.json()["detail"]
    assert detail["reason"] == "max_files"
def test_quota_check_endpoint_passes(client, auth_headers_free):
    """The quota endpoint answers 200 with ``{"ok": True}`` for 50 files."""
    payload = {"feature": "folder_index", "estimated_files": 50}
    response = client.post(
        "/api/v1/billing/quota/check",
        json=payload,
        headers=auth_headers_free,
    )
    assert response.status_code == 200
    assert response.json() == {"ok": True}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
pytestmark = pytest.mark.asyncio
def test_format_folder_manifest_basic():
    """A two-file manifest renders a <linked_folder> block listing both files."""
    text_file = {
        "relPath": "briefs/kickoff.md",
        "kind": "text",
        "summary": "Kickoff notes; scope and deadlines.",
    }
    image_file = {
        "relPath": "logos/logo-v3.png",
        "kind": "image",
        "summary": "Final logo on white.",
    }
    manifest = {
        "folderPath": "D:\\Acme",
        "lastScannedAt": "2h ago",
        "files": [text_file, image_file],
    }

    rendered = format_folder_manifest(manifest)

    assert "<linked_folder>" in rendered
    # The path may be rendered with or without a leading slash.
    assert "/briefs/kickoff.md" in rendered or "briefs/kickoff.md" in rendered
    assert "[text]" in rendered
    assert "[image]" in rendered
def test_format_folder_manifest_truncates_past_budget():
    """A 2000-file manifest is truncated and stays near the token budget."""
    files = []
    for i in range(2000):
        files.append({"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i})
    manifest = {"folderPath": "p", "lastScannedAt": "now", "files": files}

    rendered = format_folder_manifest(manifest)

    assert "more files omitted" in rendered
    # Rough token estimate: about four characters per token.
    approx_tokens = len(rendered) // 4
    assert approx_tokens < MANIFEST_TOKEN_BUDGET + 200
def test_format_folder_manifest_null_returns_empty():
    """A missing manifest or an empty file list both render as an empty string."""
    for empty_input in (None, {"files": []}):
        assert format_folder_manifest(empty_input) == ""
async def test_brief_multi_project_manifest_top_5_per_project():
    """The multi-project manifest caps a project at five files and keeps small ones."""
    acme_files = [
        {"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
        for i in range(10)
    ]
    acme = {
        "projectId": "p1",
        "projectName": "Acme",
        "folderPath": "/a",
        "lastScannedAt": "now",
        "files": acme_files,
    }
    beta = {
        "projectId": "p2",
        "projectName": "Beta",
        "folderPath": "/b",
        "lastScannedAt": "now",
        "files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
    }
    fake_client = AsyncMock(return_value={"projects": [acme, beta]})

    with patch("app.core.deep_agent.execute_on_client", new=fake_client):
        from app.core.deep_agent import build_brief_multi_project_manifest

        rendered = await build_brief_multi_project_manifest()

    # Project 1 has 10 files; only the top 5 by mtimeMs should appear.
    # NOTE(review): file lines appear to be rendered as " - /<relPath> [<kind>] <summary>",
    # so "[p1]" may never occur and this bound would then hold trivially; confirm
    # against the formatter and consider asserting on the file paths instead.
    assert rendered.count("[p1]") <= 5
    # Project 2 has a single file and must be present.
    assert "[p2]" in rendered or "Beta" in rendered

View File

@@ -0,0 +1,196 @@
"""Tests for WS folder index_session handlers (Task 9).
Tests the three handler functions directly with a minimal fake WebSocket so
no real WS connection or LLM call is made.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from app.api.routes.device_ws import (
_handle_index_session_start,
_handle_index_file_batch,
_handle_index_session_cancel,
_index_sessions,
)
from app.billing.quota import add_token_usage
from app.core.folder_indexer import IndexResult
from app.models import MonthlyTokenUsage
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS
pytestmark = pytest.mark.asyncio
USER_ID = TEST_USER_IDS["free"]
POWER_USER_ID = TEST_USER_IDS["power"]
# ── Fake WebSocket ────────────────────────────────────────────────────
class _FakeWebSocket:
"""Minimal WebSocket stand-in that records send_text calls."""
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_text(self, text: str) -> None:
self.sent.append(json.loads(text))
def sent_types(self) -> list[str]:
return [f["type"] for f in self.sent]
# ── Helpers ───────────────────────────────────────────────────────────
def _make_session_id() -> str:
import uuid
return str(uuid.uuid4())
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
    """Build an async stand-in for ``summarize_text``.

    The returned coroutine function ignores its keyword arguments and always
    resolves to ``IndexResult(summary=summary, tokens_used=tokens)``.
    """
    canned = (summary, tokens)

    async def _fake(**kwargs) -> IndexResult:
        text, count = canned
        return IndexResult(summary=text, tokens_used=count)

    return _fake
# ── Fixtures ──────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
    """Reset the module-level ``_index_sessions`` registry around every test."""
    # Start from an empty registry.
    _index_sessions.clear()
    yield
    # Drop whatever the test left behind.
    _index_sessions.clear()
# ── Tests ─────────────────────────────────────────────────────────────
async def test_index_session_happy_path(db_session):
    """Start a session and index a batch of two text files.

    Expected frames: 2 index_file_result, 1 index_session_progress and
    1 index_session_done with status "completed".
    """
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Register the session.
    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "projectId": "proj-1",
        "totalFiles": 2,
    })

    # The session is registered with zero progress.
    assert session_id in _index_sessions
    assert _index_sessions[session_id]["total"] == 2
    assert _index_sessions[session_id]["processed"] == 0

    # session_start itself sends no response frames.
    assert ws.sent == []

    # Send a batch of 2 text files. summarize_text is patched so no LLM call is
    # made. The previous `patch(...) if False else patch(...)` expression always
    # evaluated to this patch, so the unreachable branch has been removed.
    with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
        with patch("app.api.routes.device_ws.async_session") as mock_async_session:
            # Make `async with async_session()` yield the test's db_session.
            mock_cm = AsyncMock()
            mock_cm.__aenter__ = AsyncMock(return_value=db_session)
            mock_cm.__aexit__ = AsyncMock(return_value=False)
            mock_async_session.return_value = mock_cm

            await _handle_index_file_batch(ws, USER_ID, {
                "sessionId": session_id,
                "files": [
                    {"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
                    {"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
                ],
            })

    types = ws.sent_types()
    # Expect 2 file results + 1 progress + 1 done(completed).
    assert types.count(WsFrameType.index_file_result) == 2
    assert types.count(WsFrameType.index_session_progress) == 1
    assert types.count(WsFrameType.index_session_done) == 1

    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "completed"

    progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
    assert progress_frame["processed"] == 2
    assert progress_frame["total"] == 2

    # The session is removed once it completes.
    assert session_id not in _index_sessions
async def test_index_session_cancel(db_session):
    """Cancelling a started session emits index_session_done with status "cancelled"."""
    socket = _FakeWebSocket()
    sid = _make_session_id()

    start_frame = {"sessionId": sid, "totalFiles": 5}
    await _handle_index_session_start(socket, USER_ID, start_frame)
    assert sid in _index_sessions

    await _handle_index_session_cancel(socket, {"sessionId": sid})

    assert WsFrameType.index_session_done in socket.sent_types()
    done_frames = (f for f in socket.sent if f["type"] == WsFrameType.index_session_done)
    first_done = next(done_frames)
    assert first_done["status"] == "cancelled"

    # The cancelled session must no longer be registered.
    assert sid not in _index_sessions
async def test_index_session_quota_exceeded(db_session):
    """With usage pre-filled to the cap, one indexed file ends the session as quota_exceeded."""
    socket = _FakeWebSocket()
    sid = _make_session_id()

    # Seed this month's usage at exactly the free-tier cap (100_000).
    current_month = datetime.now(timezone.utc).strftime("%Y-%m")
    seeded_usage = MonthlyTokenUsage(
        user_id=USER_ID,
        year_month=current_month,
        feature="folder_index",
        tokens_used=100_000,  # free tier cap exactly
    )
    db_session.add(seeded_usage)
    await db_session.commit()

    await _handle_index_session_start(socket, USER_ID, {"sessionId": sid, "totalFiles": 1})

    batch_frame = {
        "sessionId": sid,
        "files": [
            {"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
        ],
    }
    with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
        with patch("app.api.routes.device_ws.async_session") as mock_async_session:
            # Make `async with async_session()` yield the test's db_session.
            session_cm = AsyncMock()
            session_cm.__aenter__ = AsyncMock(return_value=db_session)
            session_cm.__aexit__ = AsyncMock(return_value=False)
            mock_async_session.return_value = session_cm

            await _handle_index_file_batch(socket, USER_ID, batch_frame)

    frame_types = socket.sent_types()
    # One successful file result, then done(quota_exceeded).
    assert WsFrameType.index_file_result in frame_types
    assert WsFrameType.index_session_done in frame_types

    done_frames = (f for f in socket.sent if f["type"] == WsFrameType.index_session_done)
    first_done = next(done_frames)
    assert first_done["status"] == "quota_exceeded"

    # The session must no longer be registered.
    assert sid not in _index_sessions