15 Commits

Author SHA1 Message Date
Roberto
956fa88853 feat(api): multi-project folder manifest for daily brief
Add build_brief_multi_project_manifest() to deep_agent.py that fetches
all project folder manifests via execute_on_client and keeps the top 5
most-recently-modified files per project. Wire into run_home_brief in
brief_agent.py, injecting the <linked_folders> block into the system
prompt alongside FOLDER_TOOLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:40:47 +02:00
Roberto
fb2f59ccea feat(api): inject folder manifest into home agent when project context active
Add optional project_id param to run_home_stream. When set, fetch the linked
folder manifest via _fetch_project_manifest and prepend the <linked_folder>
block to the system prompt. Also build an explicit tools list that extends
_all_tools_for_user with FOLDER_TOOLS so the home agent can read folder
files. device_ws._handle_home_request extracts project_id / projectId from
the home_request frame and forwards it to the runner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:32:20 +02:00
Roberto
56dbb7f4cd feat(api): inject folder manifest into task brief agent
Add _fetch_project_manifest helper that calls read_project_folder_manifest
via execute_on_client. Wire it into run_task_brief_research_stream (new
optional project_id param) so the <linked_folder> block is prepended to the
system prompt when the task belongs to a linked project. Also bind
FOLDER_TOOLS into the task-brief tool palette so the agent can read folder
files. device_ws extracts project_id / projectId from the task_brief_request
frame and forwards it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:31:21 +02:00
Roberto
506f517851 feat(api): manifest formatter with token-budget truncation 2026-05-12 11:28:13 +02:00
Roberto
520c186991 feat(api): scoped read_project_folder_file tool with traversal guard
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:26:02 +02:00
Roberto
582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch,
index_file_result/progress/done), wire dispatch in device_ws message loop,
and implement _handle_index_session_start/cancel/file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00
Roberto
2aeb453229 feat(api): PDF + DOCX extraction in folder indexer
Add pypdf/python-docx deps, _extract_pdf_text/_extract_docx_text helpers,
and summarize_pdf/summarize_docx wrappers that delegate to summarize_text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:15:17 +02:00
Roberto
b7a4edac90 feat(api): folder_indexer.summarize_image via gpt-4o-mini vision
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:09:37 +02:00
Roberto
822b4cd8b1 feat(api): folder_indexer.summarize_text via gpt-4o-mini 2026-05-12 11:05:43 +02:00
Roberto
ab24fc4c91 feat(api): POST /billing/quota/check endpoint
Pre-flight quota check for folder_index. Returns 402 with reason
when file cap or monthly token budget would be exceeded; 200 {"ok": true}
otherwise. Also adds auth_headers_free fixture to conftest.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:14:56 +02:00
Roberto
a98e99f7a2 feat(api): folder quota helpers with atomic token usage
Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 08:23:22 +02:00
Roberto
a0ff285bcd feat(api): tier features for folder integration
Add folder_max_files and folder_monthly_tokens to all four tier dicts
in FEATURES, and add get_feature_value() helper to TierManager.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:39:36 +02:00
Roberto
177c1a87dd feat(api): MonthlyTokenUsage model + AgentRunLog.tokens_used
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:30:33 +02:00
Roberto
441a4ea05c chore(api): fix stale Revises comment in folder migration 2026-05-12 07:21:13 +02:00
Roberto
a693a64bf5 feat(api): add migration for folder token tracking
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:16:23 +02:00
18 changed files with 1262 additions and 8 deletions

View File

@@ -0,0 +1,46 @@
"""Add token tracking columns for folder integration.
Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"agent_run_logs",
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
)
op.create_table(
"monthly_token_usage",
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("year_month", sa.String(7), nullable=False),
sa.Column("feature", sa.String(64), nullable=False),
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
)
op.create_index(
"ix_monthly_token_usage_user_month",
"monthly_token_usage",
["user_id", "year_month"],
)
def downgrade() -> None:
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
op.drop_table("monthly_token_usage")
op.drop_column("agent_run_logs", "tokens_used")

View File

@@ -0,0 +1,37 @@
"""Scoped file-read tool for the project folder feature."""
from __future__ import annotations
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
def _is_unsafe_path(rel: str) -> bool:
if not rel:
return True
norm = rel.replace("\\", "/")
if norm.startswith("/"):
return True
# Windows drive letter
if len(rel) >= 2 and rel[1] == ":":
return True
parts = norm.split("/")
return ".." in parts
@tool
async def read_project_folder_file(project_id: str, relative_path: str) -> str:
"""Read full content of a file inside the project's linked folder."""
if _is_unsafe_path(relative_path):
return "Access denied"
result = await execute_on_client(
action="read_project_folder_file",
data={"projectId": project_id, "relativePath": relative_path},
)
content = result.get("content", "")
if not content:
return f"File not found: {relative_path}"
return content
FOLDER_TOOLS = [read_project_folder_file]

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -96,3 +96,37 @@ async def list_invoices(
""" """
invoices = await stripe_service.list_invoices(current_user.id, db) invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices return invoices
# ── Quota check ────────────────────────────────────────────────────────
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
class QuotaCheckRequest(BaseModel):
feature: str
estimated_files: int
@router.post("/quota/check")
async def quota_check(
payload: QuotaCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict:
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
if payload.feature != "folder_index":
raise HTTPException(status_code=400, detail="Unknown feature")
try:
await check_folder_quota(
user_id=current_user.id,
tier=current_user.tier,
estimated_files=payload.estimated_files,
db=db,
)
except QuotaExceeded as exc:
raise HTTPException(
status_code=402,
detail={"reason": exc.reason, "message": str(exc)},
)
return {"ok": True}

View File

@@ -57,6 +57,10 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"]) router = APIRouter(prefix="/ws", tags=["device-ws"])
# ── v7 folder index session state ─────────────────────────────────────
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
_index_sessions: dict[str, dict] = {}
_HEARTBEAT_INTERVAL = 30 # seconds _HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping _PONG_TIMEOUT = 10 # seconds — grace window after a ping
@@ -180,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_journey_message(websocket, user_id, frame) _handle_journey_message(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.index_session_start:
asyncio.create_task(
_handle_index_session_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_file_batch:
asyncio.create_task(
_handle_index_file_batch(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == "pong": elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive. # Heartbeat ack — nothing to do, connection is alive.
pass pass
@@ -211,11 +228,13 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info( logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s", "device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
user_id, user_id,
request_id, request_id,
session_id, session_id,
project_id,
message[:200], message[:200],
) )
@@ -240,7 +259,7 @@ async def _handle_home_request(
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_home_stream(user_id, message, context) event_stream = run_home_stream(user_id, message, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id) formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
@@ -437,10 +456,11 @@ async def _handle_task_brief_request(
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4()) session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or "" task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info( logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s [cache_miss]", "device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, user_id, request_id, task_id, project_id,
) )
if not task_id: if not task_id:
@@ -469,7 +489,7 @@ async def _handle_task_brief_request(
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_task_brief_research_stream(user_id, task_id, context) event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id) formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr] if ws_frame.type == "stream_text": # type: ignore[union-attr]
@@ -569,6 +589,173 @@ async def _handle_journey_message(
clear_client_executor() clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Register a new folder index session. No response sent — client is declaring intent."""
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
project_id: str | None = frame.get("projectId") or frame.get("project_id")
total: int = int(frame.get("totalFiles", 0))
if not session_id:
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
return
_index_sessions[session_id] = {
"user_id": user_id,
"project_id": project_id,
"processed": 0,
"total": total,
"cancelled": False,
}
logger.info(
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
user_id, session_id, project_id, total,
)
async def _handle_index_session_cancel(
websocket: WebSocket,
frame: dict,
) -> None:
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
session = _index_sessions.get(session_id)
if session:
session["cancelled"] = True
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "cancelled",
}))
_index_sessions.pop(session_id, None)
logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Process a batch of files for an index session, streaming results back."""
# Lazy imports to avoid heavy load at module startup.
from app.core.folder_indexer import ( # noqa: PLC0415
summarize_image,
summarize_pdf,
summarize_docx,
summarize_text,
)
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.billing.quota import add_token_usage # noqa: PLC0415
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
files: list[dict] = frame.get("files", [])
session = _index_sessions.get(session_id)
if not session or session.get("cancelled"):
return
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
cap: int | None = None if raw_cap == -1 else raw_cap
for file_info in files:
if session.get("cancelled"):
return
rel_path: str = file_info.get("relPath", "")
kind: str = file_info.get("kind", "text")
content: str = file_info.get("content", "")
ext: str = file_info.get("ext", "")
mime: str = file_info.get("mime", "application/octet-stream")
name: str = rel_path.split("/")[-1] or rel_path
try:
if kind == "image":
res = await summarize_image(image_b64=content, mime=mime)
elif kind == "pdf":
res = await summarize_pdf(pdf_b64=content, name=name)
elif kind == "docx":
res = await summarize_docx(docx_b64=content, name=name)
else:
res = await summarize_text(content=content, ext=ext, name=name)
except Exception as exc:
logger.warning(
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
session_id, rel_path, exc,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": None,
"tokensUsed": 0,
"error": str(exc),
}))
session["processed"] += 1
continue
# Account for token usage and check cap.
usage = await add_token_usage(
user_id=user_id,
feature="folder_index",
tokens=res.tokens_used,
db=db,
cap=cap,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": res.summary,
"tokensUsed": res.tokens_used,
}))
session["processed"] += 1
if usage.exhausted:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "quota_exceeded",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session quota_exceeded user=%s session=%s",
user_id, session_id,
)
return
# After processing the batch, emit progress.
processed = session["processed"]
total = session["total"]
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_progress,
"sessionId": session_id,
"processed": processed,
"total": total,
}))
if processed >= total:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "completed",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session_done completed user=%s session=%s processed=%d",
user_id, session_id, processed,
)
# ── Heartbeat ───────────────────────────────────────────────────────── # ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None: async def _heartbeat_loop(websocket: WebSocket) -> None:

139
app/billing/quota.py Normal file
View File

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
"""Raised when a folder operation cannot proceed under the user's tier."""
def __init__(self, reason: str, message: str) -> None:
super().__init__(message)
self.reason = reason # "max_files" | "monthly_tokens"
@dataclass
class TokenUsageResult:
tokens_used: int
exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
*,
user_id: str,
tier: BillingTier,
estimated_files: int,
db: AsyncSession,
) -> None:
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
would be violated. -1 in either feature means unlimited."""
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
if max_files != -1 and estimated_files > max_files:
raise QuotaExceeded(
"max_files",
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
)
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
if cap == -1:
return
ym = _current_year_month()
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)
).scalar_one_or_none()
used = row.tokens_used if row else 0
if used >= cap:
raise QuotaExceeded(
"monthly_tokens",
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
)
async def add_token_usage(
*,
user_id: str,
feature: str,
tokens: int,
db: AsyncSession,
cap: int | None = None,
) -> TokenUsageResult:
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
back to a read-then-write on other engines (e.g. aiosqlite in tests).
Returns post-update total and whether cap is exhausted.
"""
ym = _current_year_month()
# Detect dialect to choose between native upsert and portable fallback.
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
if dialect_name == "postgresql":
# Native atomic upsert — production path.
stmt = (
pg_insert(MonthlyTokenUsage)
.values(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
.on_conflict_do_update(
index_elements=["user_id", "year_month", "feature"],
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
)
.returning(MonthlyTokenUsage.tokens_used)
)
used: int = (await db.execute(stmt)).scalar_one()
await db.commit()
else:
# Portable fallback — used in tests (SQLite) and any non-PG engine.
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == feature,
)
)
).scalar_one_or_none()
if row is None:
row = MonthlyTokenUsage(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
db.add(row)
else:
row.tokens_used += tokens
await db.commit()
await db.refresh(row)
used = row.tokens_used
exhausted = cap is not None and cap != -1 and used >= cap
return TokenUsageResult(tokens_used=used, exhausted=exhausted)

View File

@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": False, # batch queue (Phase 2) "realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+ "relational_memory": False, # relational tier (Phase 3) — Pro+
"proactive_mining": False, # Power+ only (Phase 5) "proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 200,
"folder_monthly_tokens": 100_000,
}, },
"pro": { "pro": {
"agents": -1, # unlimited "agents": -1, # unlimited
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, # fire-and-forget asyncio.create_task "realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates "relational_memory": True, # person/project predicates
"proactive_mining": False, # Power+ only (Phase 5) "proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 5000,
"folder_monthly_tokens": 2_000_000,
}, },
"power": { "power": {
"agents": -1, "agents": -1,
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, "realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom "relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5) "proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
}, },
"team": { "team": {
"agents": -1, "agents": -1,
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
"realtime_extraction": True, "realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom "relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5) "proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
}, },
} }
@@ -123,6 +131,13 @@ class TierManager:
) )
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
"""Return integer feature value for tier. -1 means unlimited."""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if not isinstance(value, int):
return 0
return value
# ── Rate limiting ──────────────────────────────────────────────────── # ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int: def get_rate_limit(self, tier: BillingTier) -> int:

View File

@@ -21,6 +21,7 @@ from app.core.deep_agent import (
_relational_memory_injection, _relational_memory_injection,
_run_single_agent_stream, _run_single_agent_stream,
_trace_id_from_context, _trace_id_from_context,
build_brief_multi_project_manifest,
) )
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
@@ -159,6 +160,8 @@ async def run_home_brief(
Yields (event_type, data) tuples identical to _run_single_agent_stream. Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines. Do NOT post-process output through _normalize_tagged_list_lines.
""" """
from app.agents.folder_agent import FOLDER_TOOLS
trace_id = _trace_id_from_context(context) trace_id = _trace_id_from_context(context)
today = date.today().isoformat() today = date.today().isoformat()
language = _resolve_language(context) language = _resolve_language(context)
@@ -171,7 +174,10 @@ async def run_home_brief(
if today not in system_prompt: if today not in system_prompt:
system_prompt += f"\nToday is {today}." system_prompt += f"\nToday is {today}."
tools = _build_read_tools(user_id, trace_id) brief_manifest = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=system_prompt, system_prompt=system_prompt,

View File

@@ -60,6 +60,85 @@ def _language_instruction(context: dict[str, Any]) -> str:
f"All your output text must be written in {lang}." f"All your output text must be written in {lang}."
) )
MANIFEST_TOKEN_BUDGET = 3000 # rough budget for <linked_folder> block
def format_folder_manifest(manifest: dict | None) -> str:
"""Format a folder manifest into the <linked_folder> block.
Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
Returns empty string if manifest is None or has no files.
"""
if not manifest or not manifest.get("files"):
return ""
files = list(manifest["files"])
files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
header = (
f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
)
footer_template = "{} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token
body = ""
included = 0
for f in files:
line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
break
body += line
included += 1
omitted = len(files) - included
if omitted > 0:
return header + body + footer_template.format(omitted)
return header + body + "</linked_folder>"
async def _fetch_project_manifest(project_id: str) -> dict | None:
"""Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
from app.core.ws_context import execute_on_client
try:
result = await execute_on_client(
action="read_project_folder_manifest",
data={"projectId": project_id},
)
if not result or not result.get("folderPath"):
return None
return result
except Exception:
return None
async def build_brief_multi_project_manifest() -> str:
"""Build a compact multi-project manifest for the daily brief agent.
Calls execute_on_client('list_projects_with_folder_manifests') and keeps
the top 5 most-recently-modified files per project.
"""
try:
result = await execute_on_client(
action="list_projects_with_folder_manifests",
data={},
)
except Exception:
return ""
projects = (result or {}).get("projects") or []
if not projects:
return ""
blocks: list[str] = ["<linked_folders>"]
for p in projects:
files = sorted(p.get("files", []), key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
if not files:
continue
blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
for f in files:
blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
blocks.append("</linked_folders>")
return "\n".join(blocks)
def _datetime_context_injection(context: dict[str, Any]) -> str: def _datetime_context_injection(context: dict[str, Any]) -> str:
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges.""" """Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
fp = context.get("format_prefs") fp = context.get("format_prefs")
@@ -1328,9 +1407,22 @@ async def run_home_stream(
user_id: str, user_id: str,
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context) system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
trace_id = _trace_id_from_context(prepared_context)
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
text_chunks: list[str] = [] text_chunks: list[str] = []
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
@@ -1339,6 +1431,7 @@ async def run_home_stream(
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="home-agent", agent_name="home-agent",
tools=tools,
conversation_history=context.get("conversation_history"), conversation_history=context.get("conversation_history"),
): ):
event_type, data = event event_type, data = event
@@ -1421,6 +1514,7 @@ async def run_task_brief_research_stream(
user_id: str, user_id: str,
task_id: str, task_id: str,
context: dict[str, Any], context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
"""Stage-1 executive assistant: deep research for one task. """Stage-1 executive assistant: deep research for one task.
@@ -1428,8 +1522,10 @@ async def run_task_brief_research_stream(
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
which the WS handler strips and emits as a ``canvas_draft`` mutation. which the WS handler strips and emits as a ``canvas_draft`` mutation.
""" """
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(f"task:{task_id}", context) prepared_context = await _prepare_context(f"task:{task_id}", context)
tools = _brief_research_tools(user_id, _trace_id_from_context(prepared_context)) tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
# Inject task_id so the agent knows what to look up first. # Inject task_id so the agent knows what to look up first.
research_message = ( research_message = (
@@ -1446,6 +1542,12 @@ async def run_task_brief_research_stream(
prepared_context, prepared_context,
) )
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=system_prompt, system_prompt=system_prompt,

154
app/core/folder_indexer.py Normal file
View File

@@ -0,0 +1,154 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
summary: str
tokens_used: int
async def _llm_text(messages: list) -> object:
"""Make the LLM call for text summarisation.
Defined as a standalone async function so tests can patch it cleanly
without needing to mock the LLM object itself.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def _llm_vision(messages: list) -> object:
"""Make the LLM call for vision (image) summarisation.
Accepts the message list and returns the response directly, mirroring
the ``_llm_text`` caller pattern so tests can patch it at the module level.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def summarize_image(*, image_b64: str, mime: str) -> IndexResult:
"""Return a compact summary of an image file using vision.
Parameters
----------
image_b64:
Base64-encoded image bytes.
mime:
MIME type of the image, e.g. ``"image/png"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
messages = [
SystemMessage(content=template),
HumanMessage(content=[
{"type": "text", "text": "Summarise this image."},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
]),
]
response = await _llm_vision(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
"""Return a compact summary of a text file.
Parameters
----------
content:
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
ext:
File extension including the leading dot, e.g. ``".md"``.
name:
File name, e.g. ``"kickoff.md"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
truncated = content[:_MAX_INPUT_CHARS]
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
messages = [
SystemMessage(content=compiled),
HumanMessage(content="Summarise this file."),
]
response = await _llm_text(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(pdf_b64))
reader = PdfReader(buf)
parts: list[str] = []
for page in reader.pages:
try:
parts.append(page.extract_text() or "")
except Exception:
continue
return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(docx_b64))
doc = DocxDocument(buf)
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a PDF file.
Parameters
----------
pdf_b64:
Base64-encoded PDF bytes.
name:
File name, e.g. ``"report.pdf"``.
"""
text = _extract_pdf_text(pdf_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".pdf", name=name)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a DOCX file.
Parameters
----------
docx_b64:
Base64-encoded DOCX bytes.
name:
File name, e.g. ``"spec.docx"``.
"""
text = _extract_docx_text(docx_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".docx", name=name)

View File

@@ -243,6 +243,7 @@ class AgentRunLog(Base):
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running") status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0) items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0) items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
errors: Mapped[list | None] = mapped_column(JSON, nullable=True) errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column( started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
@@ -263,6 +264,17 @@ class AgentRunLog(Base):
) )
class MonthlyTokenUsage(Base):
__tablename__ = "monthly_token_usage"
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
# ── Memory models ───────────────────────────────────────────────────────────── # ── Memory models ─────────────────────────────────────────────────────────────

View File

@@ -89,6 +89,13 @@ class WsFrameType(str, Enum):
brief_request = "brief_request" brief_request = "brief_request"
# ── v6 task brief frame types ───────────────────────────────────── # ── v6 task brief frame types ─────────────────────────────────────
task_brief_request = "task_brief_request" task_brief_request = "task_brief_request"
# ── v7 folder index frame types ───────────────────────────────────
index_session_start = "index_session_start"
index_file_batch = "index_file_batch"
index_session_cancel = "index_session_cancel"
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
class WsToolCall(BaseModel): class WsToolCall(BaseModel):

View File

@@ -39,3 +39,5 @@ lxml>=5.0.0
PyYAML>=6.0.0 PyYAML>=6.0.0
apscheduler>=3.10.0 apscheduler>=3.10.0
ruff>=0.8.0 ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1

View File

@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select
from app.config.settings import settings from app.config.settings import settings
from app.db import Base, get_session from app.db import Base, get_session
from app.main import app from app.main import app
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── Convenience aliases and per-tier user fixtures ────────────────────
@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
"""Alias for db_session — used by folder quota tests."""
return db_session
@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
"""Return the seeded free-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["free"])
)
return result.scalar_one()
@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
"""Return the seeded power-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["power"])
)
return result.scalar_one()
@pytest.fixture
def auth_headers_free() -> dict[str, str]:
"""Authorization header for the seeded free-tier user."""
return auth_header("free")
# ── CLI options ─────────────────────────────────────────────────────── # ── CLI options ───────────────────────────────────────────────────────
def pytest_addoption(parser): def pytest_addoption(parser):

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.folder_agent import read_project_folder_file
pytestmark = pytest.mark.asyncio
async def test_happy_path():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "file body"}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
assert out == "file body"
async def test_traversal_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
assert out == "Access denied"
async def test_absolute_path_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
assert out == "Access denied"
async def test_missing_file():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": ""}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
assert "not found" in out.lower()

View File

@@ -0,0 +1,83 @@
"""Folder indexer LLM helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
pytestmark = pytest.mark.asyncio
async def test_summarize_text_returns_summary_and_tokens():
mock_resp = AsyncMock()
mock_resp.content = "Kickoff notes covering scope and deadlines."
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
assert isinstance(result, IndexResult)
assert result.summary == "Kickoff notes covering scope and deadlines."
assert result.tokens_used == 338
async def test_summarize_text_truncates_summary_at_500_chars():
mock_resp = AsyncMock()
mock_resp.content = "x" * 1000
mock_resp.usage_metadata = {"total_tokens": 100}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="x", ext=".md", name="x.md")
assert len(result.summary) <= 500
async def test_summarize_image_uses_vision_content_blocks():
mock_resp = AsyncMock()
mock_resp.content = "Final logo on white background."
mock_resp.usage_metadata = {"total_tokens": 500}
captured = {}
async def fake_llm_vision(messages):
captured["messages"] = messages
return mock_resp
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
assert "Final logo" in result.summary
assert result.tokens_used == 500
# last message contains an image content block
last = captured["messages"][-1]
assert any(
isinstance(p, dict) and p.get("type") == "image_url"
for p in (last.content if isinstance(last.content, list) else [])
)
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
# pypdf.PdfReader returns text from pages
from app.core import folder_indexer
class FakePage:
def extract_text(self): return "PDF page content with project info."
class FakeReader:
pages = [FakePage(), FakePage()]
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
assert "Project info" in result.summary
assert result.tokens_used == 50
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
from app.core import folder_indexer
class FakePara:
def __init__(self, t): self.text = t
class FakeDoc:
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
assert result.summary == "Heading and body."

View File

@@ -0,0 +1,94 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
)
assert exc.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
# No raise
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
)
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db.add(MonthlyTokenUsage(
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
))
await db.commit()
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
)
assert exc.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
await check_folder_quota(
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
)
async def test_add_token_usage_atomic_increment(db, test_user_free):
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
ym = datetime.now(timezone.utc).strftime("%Y-%m")
row = (await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == test_user_free.id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)).scalar_one()
assert row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
result = await add_token_usage(
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
)
assert result.exhausted is True
assert result.tokens_used == 150_000
def test_quota_check_endpoint_rejects(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 500},
headers=auth_headers_free,
)
assert res.status_code == 402
body = res.json()
assert body["detail"]["reason"] == "max_files"
def test_quota_check_endpoint_passes(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 50},
headers=auth_headers_free,
)
assert res.status_code == 200
assert res.json() == {"ok": True}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
pytestmark = pytest.mark.asyncio
def test_format_folder_manifest_basic():
manifest = {
"folderPath": "D:\\Acme",
"lastScannedAt": "2h ago",
"files": [
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
],
}
out = format_folder_manifest(manifest)
assert "<linked_folder>" in out
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
assert "[text]" in out
assert "[image]" in out
def test_format_folder_manifest_truncates_past_budget():
files = [
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
for i in range(2000)
]
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
assert "more files omitted" in out
# Rough token check
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
def test_format_folder_manifest_null_returns_empty():
assert format_folder_manifest(None) == ""
assert format_folder_manifest({"files": []}) == ""
async def test_brief_multi_project_manifest_top_5_per_project():
fake_response = [
{
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
"lastScannedAt": "now",
"files": [
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
for i in range(10)
],
},
{
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
"lastScannedAt": "now",
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
},
]
with patch(
"app.core.deep_agent.execute_on_client",
new=AsyncMock(return_value={"projects": fake_response}),
):
from app.core.deep_agent import build_brief_multi_project_manifest
out = await build_brief_multi_project_manifest()
# Project 1 has 10 files, only top 5 by mtimeMs should appear
assert out.count("[p1]") <= 5
# Project 2 has 1 file, must appear
assert "[p2]" in out or "Beta" in out

View File

@@ -0,0 +1,196 @@
"""Tests for WS folder index_session handlers (Task 9).
Tests the three handler functions directly with a minimal fake WebSocket so
no real WS connection or LLM call is made.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from app.api.routes.device_ws import (
_handle_index_session_start,
_handle_index_file_batch,
_handle_index_session_cancel,
_index_sessions,
)
from app.billing.quota import add_token_usage
from app.core.folder_indexer import IndexResult
from app.models import MonthlyTokenUsage
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS
pytestmark = pytest.mark.asyncio
USER_ID = TEST_USER_IDS["free"]
POWER_USER_ID = TEST_USER_IDS["power"]
# ── Fake WebSocket ────────────────────────────────────────────────────
class _FakeWebSocket:
"""Minimal WebSocket stand-in that records send_text calls."""
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_text(self, text: str) -> None:
self.sent.append(json.loads(text))
def sent_types(self) -> list[str]:
return [f["type"] for f in self.sent]
# ── Helpers ───────────────────────────────────────────────────────────
def _make_session_id() -> str:
import uuid
return str(uuid.uuid4())
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
"""Return an AsyncMock that resolves to a fixed IndexResult."""
async def _fake(**kwargs) -> IndexResult:
return IndexResult(summary=summary, tokens_used=tokens)
return _fake
# ── Fixtures ──────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
"""Ensure _index_sessions is empty before and after each test."""
_index_sessions.clear()
yield
_index_sessions.clear()
# ── Tests ─────────────────────────────────────────────────────────────
async def test_index_session_happy_path(db_session):
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
# Register session.
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"projectId": "proj-1",
"totalFiles": 2,
})
# Verify session was registered.
assert session_id in _index_sessions
assert _index_sessions[session_id]["total"] == 2
assert _index_sessions[session_id]["processed"] == 0
# No response frames expected for session_start.
assert ws.sent == []
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
with patch(
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
# We patch the module-level function in folder_indexer instead:
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
# Wire db_session into the context manager.
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_async_session.return_value = mock_cm
await _handle_index_file_batch(ws, USER_ID, {
"sessionId": session_id,
"files": [
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
],
})
types = ws.sent_types()
# Expect 2 file results + 1 progress + 1 done(completed).
assert types.count(WsFrameType.index_file_result) == 2
assert types.count(WsFrameType.index_session_progress) == 1
assert types.count(WsFrameType.index_session_done) == 1
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "completed"
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
assert progress_frame["processed"] == 2
assert progress_frame["total"] == 2
# Verify session cleaned up.
assert session_id not in _index_sessions
async def test_index_session_cancel(db_session):
"""start then cancel → index_session_done(cancelled)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"totalFiles": 5,
})
assert session_id in _index_sessions
await _handle_index_session_cancel(ws, {"sessionId": session_id})
types = ws.sent_types()
assert WsFrameType.index_session_done in types
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "cancelled"
# Session should be cleaned up.
assert session_id not in _index_sessions
async def test_index_session_quota_exceeded(db_session):
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
ws = _FakeWebSocket()
session_id = _make_session_id()
# Pre-fill monthly token usage to the free-tier cap (100_000).
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db_session.add(MonthlyTokenUsage(
user_id=USER_ID,
year_month=ym,
feature="folder_index",
tokens_used=100_000, # free tier cap exactly
))
await db_session.commit()
await _handle_index_session_start(ws, USER_ID, {
"sessionId": session_id,
"totalFiles": 1,
})
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
mock_async_session.return_value = mock_cm
await _handle_index_file_batch(ws, USER_ID, {
"sessionId": session_id,
"files": [
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
],
})
types = ws.sent_types()
# Should have 1 file result (success) then done(quota_exceeded).
assert WsFrameType.index_file_result in types
assert WsFrameType.index_session_done in types
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
assert done_frame["status"] == "quota_exceeded"
# Session should be cleaned up.
assert session_id not in _index_sessions