Compare commits
21 Commits
6f4c68b359
...
feat/proje
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc0e258e8c | ||
|
|
12e203e63d | ||
|
|
ffcd7390f0 | ||
|
|
91e880f9d4 | ||
|
|
7d47ca54be | ||
|
|
956fa88853 | ||
|
|
fb2f59ccea | ||
|
|
56dbb7f4cd | ||
|
|
506f517851 | ||
|
|
520c186991 | ||
|
|
582bf27deb | ||
|
|
2aeb453229 | ||
|
|
b7a4edac90 | ||
|
|
822b4cd8b1 | ||
|
|
ab24fc4c91 | ||
|
|
a98e99f7a2 | ||
|
|
a0ff285bcd | ||
|
|
177c1a87dd | ||
|
|
441a4ea05c | ||
|
|
a693a64bf5 | ||
|
|
67562b8092 |
@@ -56,6 +56,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
|
||||
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||
# LLM_MODEL_BRIEF_AGENT=
|
||||
|
||||
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
||||
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
||||
# LLM_MODEL_TASK_BRIEF_AGENT=
|
||||
|
||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||
LLM_MODEL_SETUP_AGENT=
|
||||
|
||||
|
||||
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Add token tracking columns for folder integration.
|
||||
|
||||
Revision ID: d6e3f4a5b6c7
|
||||
Revises: 006
|
||||
Create Date: 2026-05-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d6e3f4a5b6c7"
|
||||
down_revision: Union[str, None] = "006"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"agent_run_logs",
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.create_table(
|
||||
"monthly_token_usage",
|
||||
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("year_month", sa.String(7), nullable=False),
|
||||
sa.Column("feature", sa.String(64), nullable=False),
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_monthly_token_usage_user_month",
|
||||
"monthly_token_usage",
|
||||
["user_id", "year_month"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
||||
op.drop_table("monthly_token_usage")
|
||||
op.drop_column("agent_run_logs", "tokens_used")
|
||||
52
app/agents/client_agent.py
Normal file
52
app/agents/client_agent.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Client agent — read-only tools for the clients table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_clients(search: str = "", limit: int = 20) -> str:
|
||||
"""List clients, optionally filtered by a name/email substring search.
|
||||
|
||||
search: optional substring to match against client name or email.
|
||||
limit: max rows to return (default 20).
|
||||
"""
|
||||
filters: dict[str, Any] = {"limit": limit}
|
||||
if search:
|
||||
filters["search"] = search
|
||||
|
||||
result = await execute_on_client(action="select", table="clients", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No clients found."
|
||||
lines = [
|
||||
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
||||
f"company: {r.get('company', '')})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_client(id: str) -> str:
|
||||
"""Get full details for one client by UUID.
|
||||
|
||||
id: the client's UUID.
|
||||
"""
|
||||
if not id:
|
||||
return "Client id is required."
|
||||
|
||||
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
||||
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
||||
if not row:
|
||||
return f"Client '{id}' not found."
|
||||
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
||||
|
||||
|
||||
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
||||
168
app/agents/folder_agent.py
Normal file
168
app/agents/folder_agent.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Scoped file-read and search tools for the project folder feature."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# Cap returned slice size to keep tool output under control.
|
||||
_MAX_RETURN_CHARS = 50_000
|
||||
_MAX_SEARCH_MATCHES = 20
|
||||
|
||||
|
||||
def _is_unsafe_path(rel: str) -> bool:
|
||||
if not rel:
|
||||
return True
|
||||
norm = rel.replace("\\", "/")
|
||||
if norm.startswith("/"):
|
||||
return True
|
||||
# Windows drive letter
|
||||
if len(rel) >= 2 and rel[1] == ":":
|
||||
return True
|
||||
parts = norm.split("/")
|
||||
return ".." in parts
|
||||
|
||||
|
||||
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
||||
"""Return the raw Electron tool_result dict for a file read."""
|
||||
return await execute_on_client(
|
||||
action="read_project_folder_file",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"relativePath": relative_path,
|
||||
"offset": offset,
|
||||
"length": length,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _decode(result: dict) -> tuple[str, str, int]:
|
||||
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
||||
extracts text from base64. For images, returns a placeholder string.
|
||||
For text, content is already a sliced utf-8 string.
|
||||
"""
|
||||
kind = result.get("kind", "text")
|
||||
content = result.get("content", "") or ""
|
||||
total = int(result.get("totalSize", 0) or 0)
|
||||
if kind == "image":
|
||||
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
||||
if kind == "pdf":
|
||||
return (_extract_pdf_text(content), kind, total)
|
||||
if kind == "docx":
|
||||
return (_extract_docx_text(content), kind, total)
|
||||
return (content, kind, total)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
offset: int = 0,
|
||||
length: int = _MAX_RETURN_CHARS,
|
||||
) -> str:
|
||||
"""Read a slice of a file inside the project's linked folder.
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
offset: char offset to start reading from (0 = beginning).
|
||||
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
||||
|
||||
Returns text content slice with a header showing position. Header tells you
|
||||
when more content is available; call again with the suggested next offset.
|
||||
|
||||
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
||||
on the extracted text. For images returns a placeholder; navigate with the
|
||||
manifest summary instead.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
|
||||
result = await _fetch_file(project_id, relative_path, offset, length)
|
||||
text, kind, total_size = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
|
||||
if kind in ("pdf", "docx"):
|
||||
# Backend extracted full text — apply offset/length on chars.
|
||||
sliced = text[offset:offset + length]
|
||||
slice_end = min(offset + length, len(text))
|
||||
header = (
|
||||
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
||||
f"totalChars={len(text)}]"
|
||||
)
|
||||
if slice_end < len(text):
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + sliced
|
||||
|
||||
if kind == "text":
|
||||
slice_end = offset + len(text)
|
||||
header = (
|
||||
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
||||
f"totalBytes={total_size}]"
|
||||
)
|
||||
if slice_end < total_size:
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + text
|
||||
|
||||
# image or unknown
|
||||
return text
|
||||
|
||||
|
||||
@tool
|
||||
async def search_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
query: str,
|
||||
context_lines: int = 3,
|
||||
) -> str:
|
||||
"""Search a project folder file for a query string (case-insensitive substring).
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
query: text to search for.
|
||||
context_lines: number of lines of context around each match (default 3).
|
||||
|
||||
Returns matching line ranges with surrounding context and 1-based line numbers.
|
||||
Capped at 20 matches; if more exist the header shows the total.
|
||||
|
||||
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
||||
Images and binary files are not searchable.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
if not query:
|
||||
return "Empty query."
|
||||
|
||||
# For text we still need full file; pass length=very large.
|
||||
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
||||
text, kind, _ = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
if kind == "image":
|
||||
return "Cannot search inside images."
|
||||
|
||||
lines = text.splitlines()
|
||||
q = query.lower()
|
||||
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
||||
if not matches:
|
||||
return f"No matches for '{query}' in {relative_path}."
|
||||
|
||||
shown = matches[:_MAX_SEARCH_MATCHES]
|
||||
snippets: list[str] = []
|
||||
for i in shown:
|
||||
start = max(0, i - context_lines)
|
||||
end = min(len(lines), i + context_lines + 1)
|
||||
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
||||
snippets.append(block)
|
||||
|
||||
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
||||
body = "\n---\n".join(snippets)
|
||||
return header + "\n" + body
|
||||
|
||||
|
||||
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
||||
63
app/agents/relations_agent.py
Normal file
63
app/agents/relations_agent.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
|
||||
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
||||
# Each tool closure captures the user_id bound at factory time.
|
||||
|
||||
|
||||
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
||||
"""Return a query_relations tool bound to *user_id*."""
|
||||
|
||||
@tool
|
||||
async def query_relations(
|
||||
subject_label: str = "",
|
||||
predicate: str = "",
|
||||
object_label: str = "",
|
||||
limit: int = 10,
|
||||
) -> str:
|
||||
"""Query the relational memory graph for entity relationships.
|
||||
|
||||
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
||||
All parameters are optional — omit to retrieve all relations up to limit.
|
||||
|
||||
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
||||
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
||||
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
||||
limit: max rows to return (default 10).
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(
|
||||
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
||||
trace_id or "-", user_id, subject_label, predicate, object_label,
|
||||
)
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
rows = await memory.query_relations(
|
||||
user_id=user_id,
|
||||
subject=subject_label or None,
|
||||
predicate=predicate or None,
|
||||
object_=object_label or None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return "No relational memory entries found for the given filters."
|
||||
|
||||
lines = [
|
||||
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
||||
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
||||
|
||||
return query_relations
|
||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request, status
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -96,3 +96,37 @@ async def list_invoices(
|
||||
"""
|
||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||
return invoices
|
||||
|
||||
|
||||
# ── Quota check ────────────────────────────────────────────────────────
|
||||
|
||||
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
||||
|
||||
|
||||
class QuotaCheckRequest(BaseModel):
|
||||
feature: str
|
||||
estimated_files: int
|
||||
|
||||
|
||||
@router.post("/quota/check")
|
||||
async def quota_check(
|
||||
payload: QuotaCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
||||
if payload.feature != "folder_index":
|
||||
raise HTTPException(status_code=400, detail="Unknown feature")
|
||||
try:
|
||||
await check_folder_quota(
|
||||
user_id=current_user.id,
|
||||
tier=current_user.tier,
|
||||
estimated_files=payload.estimated_files,
|
||||
db=db,
|
||||
)
|
||||
except QuotaExceeded as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={"reason": exc.reason, "message": str(exc)},
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -43,7 +43,8 @@ from app.api.routes.agent_setup import handle_journey_message, handle_journey_st
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
|
||||
from app.core.output_formatter import extract_canvas_block
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.output_formatter import StreamFormatter
|
||||
@@ -56,6 +57,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
# ── v7 folder index session state ─────────────────────────────────────
|
||||
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||
_index_sessions: dict[str, dict] = {}
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
@@ -164,6 +169,11 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.task_brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_task_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_start:
|
||||
asyncio.create_task(
|
||||
_handle_journey_start(websocket, user_id, frame)
|
||||
@@ -174,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_journey_message(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_start:
|
||||
asyncio.create_task(
|
||||
_handle_index_session_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_file_batch:
|
||||
asyncio.create_task(
|
||||
_handle_index_file_batch(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_cancel:
|
||||
await _handle_index_session_cancel(websocket, frame)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
@@ -205,11 +228,13 @@ async def _handle_home_request(
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
logger.info(
|
||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
project_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
@@ -234,7 +259,7 @@ async def _handle_home_request(
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_home_stream(user_id, message, context)
|
||||
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
@@ -415,6 +440,98 @@ async def _handle_brief_request(
|
||||
)
|
||||
|
||||
|
||||
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_task_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
||||
|
||||
Streams the briefing markdown back to the client.
|
||||
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
||||
"""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
||||
user_id, request_id, task_id, project_id,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
f"task brief: {task_id}",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
|
||||
try:
|
||||
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
elif ws_frame.type == "stream_start":
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# stream_end is emitted below with mutations — skip formatter's version
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
||||
user_id, request_id, task_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Extract canvas block then emit stream_end with optional mutations.
|
||||
full_response = "".join(response_chunks)
|
||||
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
||||
|
||||
mutations: list[dict] = []
|
||||
if canvas_content:
|
||||
mutations.append({
|
||||
"type": "canvas_draft",
|
||||
"content": canvas_content,
|
||||
"kind": canvas_kind,
|
||||
})
|
||||
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
||||
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
||||
)
|
||||
|
||||
|
||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -472,6 +589,174 @@ async def _handle_journey_message(
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_index_session_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
||||
|
||||
if not session_id:
|
||||
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||
return
|
||||
|
||||
_index_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"processed": 0,
|
||||
"total": total,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(
|
||||
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||
user_id, session_id, project_id, total,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_index_session_cancel(
|
||||
websocket: WebSocket,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
session = _index_sessions.get(session_id)
|
||||
if session:
|
||||
session["cancelled"] = True
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "cancelled",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||
|
||||
|
||||
async def _handle_index_file_batch(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Process a batch of files for an index session, streaming results back."""
|
||||
# Lazy imports to avoid heavy load at module startup.
|
||||
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||
summarize_image,
|
||||
summarize_pdf,
|
||||
summarize_docx,
|
||||
summarize_text,
|
||||
)
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
files: list[dict] = frame.get("files", [])
|
||||
|
||||
session = _index_sessions.get(session_id)
|
||||
if not session or session.get("cancelled"):
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||
|
||||
for file_info in files:
|
||||
if session.get("cancelled"):
|
||||
return
|
||||
|
||||
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
||||
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
||||
kind: str = file_info.get("kind") or "text"
|
||||
content: str = file_info.get("content") or ""
|
||||
ext: str = file_info.get("ext") or ""
|
||||
mime: str = file_info.get("mime") or "application/octet-stream"
|
||||
name: str = rel_path.split("/")[-1] or rel_path
|
||||
|
||||
try:
|
||||
if kind == "image":
|
||||
res = await summarize_image(image_b64=content, mime=mime)
|
||||
elif kind == "pdf":
|
||||
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||
elif kind == "docx":
|
||||
res = await summarize_docx(docx_b64=content, name=name)
|
||||
else:
|
||||
res = await summarize_text(content=content, ext=ext, name=name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||
session_id, rel_path, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": None,
|
||||
"tokensUsed": 0,
|
||||
"error": str(exc),
|
||||
}))
|
||||
session["processed"] += 1
|
||||
continue
|
||||
|
||||
# Account for token usage and check cap.
|
||||
usage = await add_token_usage(
|
||||
user_id=user_id,
|
||||
feature="folder_index",
|
||||
tokens=res.tokens_used,
|
||||
db=db,
|
||||
cap=cap,
|
||||
)
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": res.summary,
|
||||
"tokensUsed": res.tokens_used,
|
||||
}))
|
||||
session["processed"] += 1
|
||||
|
||||
if usage.exhausted:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "quota_exceeded",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||
user_id, session_id,
|
||||
)
|
||||
return
|
||||
|
||||
# After processing the batch, emit progress.
|
||||
processed = session["processed"]
|
||||
total = session["total"]
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_progress,
|
||||
"sessionId": session_id,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
}))
|
||||
|
||||
if processed >= total:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "completed",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||
user_id, session_id, processed,
|
||||
)
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
|
||||
139
app/billing/quota.py
Normal file
139
app/billing/quota.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Quota checks and atomic token-usage accounting for folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.billing.tier_manager import TierManager
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import BillingTier
|
||||
|
||||
|
||||
class QuotaExceeded(Exception):
|
||||
"""Raised when a folder operation cannot proceed under the user's tier."""
|
||||
|
||||
def __init__(self, reason: str, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.reason = reason # "max_files" | "monthly_tokens"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsageResult:
|
||||
tokens_used: int
|
||||
exhausted: bool
|
||||
|
||||
|
||||
def _current_year_month() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
|
||||
_tier_manager = TierManager()
|
||||
|
||||
|
||||
async def check_folder_quota(
|
||||
*,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
estimated_files: int,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
||||
would be violated. -1 in either feature means unlimited."""
|
||||
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
||||
if max_files != -1 and estimated_files > max_files:
|
||||
raise QuotaExceeded(
|
||||
"max_files",
|
||||
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
||||
)
|
||||
|
||||
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
if cap == -1:
|
||||
return
|
||||
ym = _current_year_month()
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == "folder_index",
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
used = row.tokens_used if row else 0
|
||||
if used >= cap:
|
||||
raise QuotaExceeded(
|
||||
"monthly_tokens",
|
||||
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
||||
)
|
||||
|
||||
|
||||
async def add_token_usage(
|
||||
*,
|
||||
user_id: str,
|
||||
feature: str,
|
||||
tokens: int,
|
||||
db: AsyncSession,
|
||||
cap: int | None = None,
|
||||
) -> TokenUsageResult:
|
||||
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
||||
|
||||
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
||||
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
||||
Returns post-update total and whether cap is exhausted.
|
||||
"""
|
||||
ym = _current_year_month()
|
||||
|
||||
# Detect dialect to choose between native upsert and portable fallback.
|
||||
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
# Native atomic upsert — production path.
|
||||
stmt = (
|
||||
pg_insert(MonthlyTokenUsage)
|
||||
.values(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["user_id", "year_month", "feature"],
|
||||
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
||||
)
|
||||
.returning(MonthlyTokenUsage.tokens_used)
|
||||
)
|
||||
used: int = (await db.execute(stmt)).scalar_one()
|
||||
await db.commit()
|
||||
else:
|
||||
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == feature,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
row = MonthlyTokenUsage(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
db.add(row)
|
||||
else:
|
||||
row.tokens_used += tokens
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
used = row.tokens_used
|
||||
|
||||
exhausted = cap is not None and cap != -1 and used >= cap
|
||||
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
||||
@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 200,
|
||||
"folder_monthly_tokens": 100_000,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"relational_memory": True, # person/project predicates
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 5000,
|
||||
"folder_monthly_tokens": 2_000_000,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
}
|
||||
|
||||
@@ -123,6 +131,13 @@ class TierManager:
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
||||
"""Return integer feature value for tier. -1 means unlimited."""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if not isinstance(value, int):
|
||||
return 0
|
||||
return value
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
|
||||
@@ -29,6 +29,7 @@ class Settings(BaseSettings):
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||
|
||||
@@ -21,6 +21,7 @@ from app.core.deep_agent import (
|
||||
_relational_memory_injection,
|
||||
_run_single_agent_stream,
|
||||
_trace_id_from_context,
|
||||
build_brief_multi_project_manifest,
|
||||
)
|
||||
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||
|
||||
@@ -159,6 +160,8 @@ async def run_home_brief(
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
from app.agents.folder_agent import FOLDER_TOOLS
|
||||
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
@@ -171,7 +174,10 @@ async def run_home_brief(
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
tools = _build_read_tools(user_id, trace_id)
|
||||
brief_manifest = await build_brief_multi_project_manifest()
|
||||
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
||||
|
||||
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
|
||||
@@ -12,8 +12,10 @@ from typing import Any, Literal
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.client_agent import CLIENT_TOOLS
|
||||
from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.relations_agent import make_query_relations_tool
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.core.agent_session_buffer import session_buffer
|
||||
@@ -58,6 +60,93 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
||||
f"All your output text must be written in {lang}."
|
||||
)
|
||||
|
||||
MANIFEST_TOKEN_BUDGET = 3000 # rough budget for <linked_folder> block
|
||||
|
||||
|
||||
def format_folder_manifest(manifest: dict | None) -> str:
|
||||
"""Format a folder manifest into the <linked_folder> block.
|
||||
|
||||
Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
|
||||
Returns empty string if manifest is None or has no files.
|
||||
"""
|
||||
if not manifest or not manifest.get("files"):
|
||||
return ""
|
||||
files = list(manifest["files"])
|
||||
files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
|
||||
|
||||
header = (
|
||||
f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
|
||||
f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
|
||||
)
|
||||
footer_template = "… {} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
|
||||
|
||||
char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token
|
||||
body = ""
|
||||
included = 0
|
||||
for f in files:
|
||||
line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
|
||||
if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
|
||||
break
|
||||
body += line
|
||||
included += 1
|
||||
omitted = len(files) - included
|
||||
if omitted > 0:
|
||||
return header + body + footer_template.format(omitted)
|
||||
return header + body + "</linked_folder>"
|
||||
|
||||
|
||||
async def _fetch_project_manifest(project_id: str) -> dict | None:
|
||||
"""Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
|
||||
from app.core.ws_context import execute_on_client
|
||||
try:
|
||||
result = await execute_on_client(
|
||||
action="read_project_folder_manifest",
|
||||
data={"projectId": project_id},
|
||||
)
|
||||
if not result or not result.get("folderPath"):
|
||||
return None
|
||||
return result
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def build_brief_multi_project_manifest() -> str:
|
||||
"""Build a compact multi-project manifest for the daily brief agent.
|
||||
|
||||
Calls execute_on_client('list_projects_with_folder_manifests') and keeps
|
||||
the top 5 most-recently-modified files per project.
|
||||
"""
|
||||
try:
|
||||
result = await execute_on_client(
|
||||
action="list_projects_with_folder_manifests",
|
||||
data={},
|
||||
)
|
||||
except Exception:
|
||||
return ""
|
||||
projects = (result or {}).get("projects") or []
|
||||
if not projects:
|
||||
return ""
|
||||
blocks: list[str] = ["<linked_folders>"]
|
||||
any_entry = False
|
||||
for p in projects:
|
||||
all_files = p.get("files", []) or []
|
||||
files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
|
||||
blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
|
||||
blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
|
||||
if not all_files:
|
||||
blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)")
|
||||
else:
|
||||
for f in files:
|
||||
blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
|
||||
if len(all_files) > 5:
|
||||
blocks.append(f" … {len(all_files) - 5} more files (use read_project_folder_file by relPath)")
|
||||
any_entry = True
|
||||
if not any_entry:
|
||||
return ""
|
||||
blocks.append("</linked_folders>")
|
||||
return "\n".join(blocks)
|
||||
|
||||
|
||||
def _datetime_context_injection(context: dict[str, Any]) -> str:
|
||||
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
|
||||
fp = context.get("format_prefs")
|
||||
@@ -303,6 +392,80 @@ For specific dates not listed, compute local-midnight in the user timezone and c
|
||||
{request_context}\
|
||||
"""
|
||||
|
||||
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
|
||||
You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task.
|
||||
Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity}
|
||||
|
||||
# Research workflow
|
||||
Follow these steps in order, using tools:
|
||||
1. Read the task fully (title, description, due date, priority, status, project, comments).
|
||||
2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client.
|
||||
3. If the project has a clientId: call `get_client(id)` to retrieve full client details.
|
||||
4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects.
|
||||
5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions.
|
||||
6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`.
|
||||
7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece.
|
||||
|
||||
# Output structure
|
||||
Write the briefing in the user's language. Use this exact structure:
|
||||
|
||||
**What needs to be done**
|
||||
(1–2 sentences, concrete and specific — what action the user must take)
|
||||
|
||||
**Context you should know**
|
||||
(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies)
|
||||
|
||||
**Suggested first step**
|
||||
(one specific, immediately actionable instruction)
|
||||
|
||||
If this is a writing task, append a canvas block at the very end:
|
||||
<canvas kind="email|document|message">
|
||||
...ready-to-use draft here...
|
||||
</canvas>
|
||||
|
||||
Do NOT include the canvas block for non-writing tasks.
|
||||
Do NOT repeat verbatim task fields the user already sees in the UI.
|
||||
Be concrete — no vague advice. Every bullet should be a fact that changes what the user does.
|
||||
|
||||
# Date context
|
||||
{date_context}
|
||||
|
||||
# Language
|
||||
{language_instruction}
|
||||
|
||||
# Known people & projects
|
||||
{relational_memory}
|
||||
|
||||
# Request context
|
||||
{request_context}\
|
||||
"""
|
||||
|
||||
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\
|
||||
You are an executive assistant continuing a conversation with your principal.
|
||||
You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity}
|
||||
|
||||
Your briefing:
|
||||
---
|
||||
{briefing_context}
|
||||
---
|
||||
|
||||
Continue from here. Do NOT repeat the briefing. Refer to it when relevant.
|
||||
Help the user execute: edit drafts, refine wording, look up additional details, plan next steps.
|
||||
Stay terse — your principal is a busy executive.
|
||||
|
||||
# Date context
|
||||
{date_context}
|
||||
|
||||
# Language
|
||||
{language_instruction}
|
||||
|
||||
# Known people & projects
|
||||
{relational_memory}
|
||||
|
||||
# Request context
|
||||
{request_context}\
|
||||
"""
|
||||
|
||||
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
|
||||
"You are a strict domain classifier for websocket floating requests. "
|
||||
"Return ONLY a JSON object with keys: type, id, section. "
|
||||
@@ -679,6 +842,25 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
lines = [f"- {item}" for item in results]
|
||||
return "Recall memory results:\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def search_associative(query: str, limit: int = 5) -> str:
|
||||
"""Semantic search across associative (archival) memory for a given query.
|
||||
|
||||
Use this to surface long-term memories related to a topic, client, or task
|
||||
that may not appear in recent episodes.
|
||||
|
||||
query: natural-language search phrase.
|
||||
limit: max results (default 5).
|
||||
"""
|
||||
logger.info("deep_agent: search_associative trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
results = await memory.search_archival(user_id, query, top_k=limit)
|
||||
if not results:
|
||||
return "No associative memory results found."
|
||||
lines = [f"- {item}" for item in results]
|
||||
return "Associative memory results:\n" + "\n".join(lines)
|
||||
|
||||
return [
|
||||
memory_list_blocks,
|
||||
memory_get,
|
||||
@@ -689,16 +871,33 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
archival_memory_insert,
|
||||
archival_memory_search,
|
||||
conversation_search,
|
||||
search_associative,
|
||||
]
|
||||
|
||||
|
||||
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
"""Return memory tools that only read — safe for the read-only brief-agent subset."""
|
||||
all_mem = _memory_tools(user_id, trace_id)
|
||||
_read_names = {"memory_list_blocks", "memory_get", "archival_memory_search", "conversation_search"}
|
||||
_read_names = {
|
||||
"memory_list_blocks", "memory_get", "archival_memory_search",
|
||||
"conversation_search", "search_associative",
|
||||
}
|
||||
return [t for t in all_mem if t.name in _read_names]
|
||||
|
||||
|
||||
def _brief_research_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
"""Return the full tool palette for Stage-1 task brief research (read-only)."""
|
||||
return [
|
||||
*TASK_TOOLS,
|
||||
*PROJECT_TOOLS,
|
||||
*NOTE_TOOLS,
|
||||
*TIMELINE_TOOLS,
|
||||
*CLIENT_TOOLS,
|
||||
*_read_only_memory_tools(user_id, trace_id),
|
||||
make_query_relations_tool(user_id, trace_id),
|
||||
]
|
||||
|
||||
|
||||
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||
|
||||
@@ -1216,9 +1415,26 @@ async def run_home_stream(
|
||||
user_id: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
project_id: str | None = None,
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
from app.agents.folder_agent import FOLDER_TOOLS
|
||||
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
|
||||
|
||||
manifest_block = ""
|
||||
if project_id:
|
||||
manifest = await _fetch_project_manifest(project_id)
|
||||
manifest_block = format_folder_manifest(manifest)
|
||||
if not manifest_block:
|
||||
# No specific project context — surface all linked folders so the agent
|
||||
# can answer questions like "tell me about project X" using its files.
|
||||
manifest_block = await build_brief_multi_project_manifest()
|
||||
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
|
||||
|
||||
trace_id = _trace_id_from_context(prepared_context)
|
||||
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
|
||||
|
||||
text_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
@@ -1227,6 +1443,7 @@ async def run_home_stream(
|
||||
context=prepared_context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="home-agent",
|
||||
tools=tools,
|
||||
conversation_history=context.get("conversation_history"),
|
||||
):
|
||||
event_type, data = event
|
||||
@@ -1249,6 +1466,28 @@ async def run_floating_stream(
|
||||
domain = await _infer_floating_domain(message, prepared_context)
|
||||
yield "floating_domain", domain
|
||||
|
||||
brief_mode: bool = bool(context.get("brief_mode"))
|
||||
briefing_context_text: str = str(context.get("briefing_context") or "").strip()
|
||||
|
||||
if brief_mode and briefing_context_text:
|
||||
# Stage 2: inject briefing as ground truth context.
|
||||
# Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
|
||||
# before compile_prompt sees the remaining standard variables.
|
||||
template, langfuse_prompt = get_prompt_or_fallback(
|
||||
"task_brief_followup_system",
|
||||
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
|
||||
)
|
||||
system_prompt = compile_prompt(
|
||||
template, langfuse_prompt,
|
||||
date_context=_datetime_context_injection(prepared_context).strip(),
|
||||
language_instruction=_language_instruction(prepared_context).strip(),
|
||||
user_identity=_user_identity_injection(prepared_context).strip(),
|
||||
relational_memory=_relational_memory_injection(prepared_context).strip(),
|
||||
proactive_hints=_proactive_hints_injection(prepared_context).strip(),
|
||||
request_context=_request_context_block(prepared_context),
|
||||
briefing_context=briefing_context_text,
|
||||
)
|
||||
else:
|
||||
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
|
||||
sanitizer = _FloatingStreamSanitizer()
|
||||
emitted_sanitized = False
|
||||
@@ -1283,6 +1522,58 @@ async def run_floating_stream(
|
||||
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||
|
||||
|
||||
async def run_task_brief_research_stream(
|
||||
user_id: str,
|
||||
task_id: str,
|
||||
context: dict[str, Any],
|
||||
project_id: str | None = None,
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stage-1 executive assistant: deep research for one task.
|
||||
|
||||
Yields ``("token", chunk)`` events like other stream runners.
|
||||
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
|
||||
which the WS handler strips and emits as a ``canvas_draft`` mutation.
|
||||
"""
|
||||
from app.agents.folder_agent import FOLDER_TOOLS
|
||||
|
||||
prepared_context = await _prepare_context(f"task:{task_id}", context)
|
||||
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
|
||||
|
||||
# Inject task_id so the agent knows what to look up first.
|
||||
research_message = (
|
||||
f"Prepare a briefing dossier for task ID: {task_id}\n"
|
||||
"Follow the research workflow: read the task, then project, then client, "
|
||||
"then cross-project relations, then relevant memory. "
|
||||
"End with a concrete suggested first step. "
|
||||
"If this is a writing task, include a <canvas kind=\"...\"> draft."
|
||||
)
|
||||
|
||||
system_prompt, langfuse_prompt = _build_system_prompt(
|
||||
"task_brief_research_system",
|
||||
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT,
|
||||
prepared_context,
|
||||
)
|
||||
|
||||
manifest_block = ""
|
||||
if project_id:
|
||||
manifest = await _fetch_project_manifest(project_id)
|
||||
manifest_block = format_folder_manifest(manifest)
|
||||
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
|
||||
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message=research_message,
|
||||
context=prepared_context,
|
||||
max_steps=12,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="task-brief-agent",
|
||||
tools=tools,
|
||||
conversation_history=None,
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||
async with async_session() as db:
|
||||
|
||||
183
app/core/folder_indexer.py
Normal file
183
app/core/folder_indexer.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Per-file summarisation for project folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pypdf import PdfReader
|
||||
from docx import Document as DocxDocument
|
||||
|
||||
from app.core.langfuse_client import (
|
||||
compile_prompt,
|
||||
extract_usage,
|
||||
get_langfuse,
|
||||
get_prompt_or_fallback,
|
||||
)
|
||||
from app.core.llm import get_llm
|
||||
|
||||
_TEXT_FALLBACK = (
|
||||
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
||||
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
||||
)
|
||||
_IMAGE_FALLBACK = (
|
||||
"You are summarising an image attached to a project folder.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
||||
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
||||
)
|
||||
_MAX_INPUT_CHARS = 6000
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
summary: str
|
||||
tokens_used: int
|
||||
|
||||
|
||||
async def _llm_text(messages: list) -> object:
|
||||
"""Make the LLM call for text summarisation.
|
||||
|
||||
Defined as a standalone async function so tests can patch it cleanly
|
||||
without needing to mock the LLM object itself.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def _llm_vision(messages: list) -> object:
|
||||
"""Make the LLM call for vision (image) summarisation.
|
||||
|
||||
Accepts the message list and returns the response directly, mirroring
|
||||
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
||||
"""Return a compact summary of an image file using vision.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_b64:
|
||||
Base64-encoded image bytes.
|
||||
mime:
|
||||
MIME type of the image, e.g. ``"image/png"``.
|
||||
file_name:
|
||||
Optional file name, attached to the Langfuse trace as input metadata.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
||||
messages = [
|
||||
SystemMessage(content=template),
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Summarise this image."},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
||||
]),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-image",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": file_name, "mime": mime},
|
||||
) as gen:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a text file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
content:
|
||||
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
||||
ext:
|
||||
File extension including the leading dot, e.g. ``".md"``.
|
||||
name:
|
||||
File name, e.g. ``"kickoff.md"``.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
||||
truncated = content[:_MAX_INPUT_CHARS]
|
||||
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
||||
messages = [
|
||||
SystemMessage(content=compiled),
|
||||
HumanMessage(content="Summarise this file."),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-text",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
||||
) as gen:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
def _extract_pdf_text(pdf_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
||||
reader = PdfReader(buf)
|
||||
parts: list[str] = []
|
||||
for page in reader.pages:
|
||||
try:
|
||||
parts.append(page.extract_text() or "")
|
||||
except Exception:
|
||||
continue
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
def _extract_docx_text(docx_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(docx_b64))
|
||||
doc = DocxDocument(buf)
|
||||
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
||||
|
||||
|
||||
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a PDF file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pdf_b64:
|
||||
Base64-encoded PDF bytes.
|
||||
name:
|
||||
File name, e.g. ``"report.pdf"``.
|
||||
"""
|
||||
text = _extract_pdf_text(pdf_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".pdf", name=name)
|
||||
|
||||
|
||||
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a DOCX file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
docx_b64:
|
||||
Base64-encoded DOCX bytes.
|
||||
name:
|
||||
File name, e.g. ``"spec.docx"``.
|
||||
"""
|
||||
text = _extract_docx_text(docx_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".docx", name=name)
|
||||
@@ -107,6 +107,7 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
|
||||
@@ -2,11 +2,35 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
||||
_CANVAS_BLOCK_RE = re.compile(
|
||||
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
||||
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
||||
|
||||
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
||||
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
||||
"""
|
||||
match = _CANVAS_BLOCK_RE.search(text)
|
||||
if not match:
|
||||
return text, None, None
|
||||
|
||||
canvas_kind = match.group(1).strip()
|
||||
canvas_content = match.group(2).strip()
|
||||
visible = text[: match.start()] + text[match.end() :]
|
||||
visible = visible.strip()
|
||||
return visible, canvas_content, canvas_kind
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||
|
||||
|
||||
|
||||
@@ -243,6 +243,7 @@ class AgentRunLog(Base):
|
||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
@@ -263,6 +264,17 @@ class AgentRunLog(Base):
|
||||
)
|
||||
|
||||
|
||||
class MonthlyTokenUsage(Base):
|
||||
__tablename__ = "monthly_token_usage"
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
|
||||
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
|
||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -87,6 +87,15 @@ class WsFrameType(str, Enum):
|
||||
journey_reply = "journey_reply"
|
||||
# ── v5 brief frame types ──────────────────────────────────────────
|
||||
brief_request = "brief_request"
|
||||
# ── v6 task brief frame types ─────────────────────────────────────
|
||||
task_brief_request = "task_brief_request"
|
||||
# ── v7 folder index frame types ───────────────────────────────────
|
||||
index_session_start = "index_session_start"
|
||||
index_file_batch = "index_file_batch"
|
||||
index_session_cancel = "index_session_cancel"
|
||||
index_file_result = "index_file_result"
|
||||
index_session_progress = "index_session_progress"
|
||||
index_session_done = "index_session_done"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
@@ -209,6 +218,7 @@ class WsStreamEnd(BaseModel):
|
||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||
request_id: str
|
||||
error: str | None = None
|
||||
mutations: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
class WsDomain(BaseModel):
|
||||
|
||||
@@ -39,3 +39,5 @@ lxml>=5.0.0
|
||||
PyYAML>=6.0.0
|
||||
apscheduler>=3.10.0
|
||||
ruff>=0.8.0
|
||||
pypdf>=4.0
|
||||
python-docx>=1.1
|
||||
|
||||
@@ -17,6 +17,8 @@ from jose import jwt
|
||||
from sqlalchemy import StaticPool, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import Base, get_session
|
||||
from app.main import app
|
||||
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||
|
||||
|
||||
# ── Convenience aliases and per-tier user fixtures ────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db(db_session: AsyncSession) -> AsyncSession:
|
||||
"""Alias for db_session — used by folder quota tests."""
|
||||
return db_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_free(db_session: AsyncSession):
|
||||
"""Return the seeded free-tier User row."""
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.id == TEST_USER_IDS["free"])
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_power(db_session: AsyncSession):
|
||||
"""Return the seeded power-tier User row."""
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.id == TEST_USER_IDS["power"])
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers_free() -> dict[str, str]:
|
||||
"""Authorization header for the seeded free-tier user."""
|
||||
return auth_header("free")
|
||||
|
||||
|
||||
# ── CLI options ───────────────────────────────────────────────────────
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
139
tests/test_folder_agent_tool.py
Normal file
139
tests/test_folder_agent_tool.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.folder_agent import (
|
||||
read_project_folder_file,
|
||||
search_project_folder_file,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_happy_path():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
|
||||
assert "file body" in out
|
||||
assert "kind=text" in out
|
||||
|
||||
|
||||
async def test_traversal_rejected():
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
|
||||
assert out == "Access denied"
|
||||
|
||||
|
||||
async def test_absolute_path_rejected():
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
|
||||
assert out == "Access denied"
|
||||
|
||||
|
||||
async def test_missing_file():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
|
||||
assert "not found" in out.lower()
|
||||
|
||||
|
||||
async def test_pagination_signals_more_available():
|
||||
# Electron returned the first slice, totalSize larger than slice length.
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "big.txt",
|
||||
"offset": 0,
|
||||
"length": 11,
|
||||
})
|
||||
assert "first chunk" in out
|
||||
assert "More content available" in out
|
||||
assert "offset=11" in out
|
||||
|
||||
|
||||
async def test_pdf_extracted_then_sliced(monkeypatch):
|
||||
from app.agents import folder_agent
|
||||
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "doc.pdf",
|
||||
"offset": 0,
|
||||
"length": 8,
|
||||
})
|
||||
assert "kind=pdf" in out
|
||||
assert "ABC ABC " in out
|
||||
assert "More content available" in out
|
||||
|
||||
|
||||
async def test_image_returns_placeholder():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
|
||||
assert "image" in out.lower()
|
||||
|
||||
|
||||
async def test_search_finds_match_with_context():
|
||||
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
|
||||
):
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "log.txt",
|
||||
"query": "needle",
|
||||
"context_lines": 1,
|
||||
})
|
||||
assert "needle" in out
|
||||
assert "matches=1" in out
|
||||
# Context lines included
|
||||
assert "beta" in out
|
||||
assert "gamma" in out
|
||||
|
||||
|
||||
async def test_search_no_match():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
|
||||
):
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "x.txt",
|
||||
"query": "zzz",
|
||||
})
|
||||
assert "No matches" in out
|
||||
|
||||
|
||||
async def test_search_rejects_traversal():
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "../etc/passwd",
|
||||
"query": "root",
|
||||
})
|
||||
assert out == "Access denied"
|
||||
|
||||
|
||||
async def test_search_image_rejected():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
|
||||
):
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "logo.png",
|
||||
"query": "anything",
|
||||
})
|
||||
assert "Cannot search" in out
|
||||
83
tests/test_folder_indexer.py
Normal file
83
tests/test_folder_indexer.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Folder indexer LLM helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_summarize_text_returns_summary_and_tokens():
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.content = "Kickoff notes covering scope and deadlines."
|
||||
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
|
||||
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
|
||||
assert isinstance(result, IndexResult)
|
||||
assert result.summary == "Kickoff notes covering scope and deadlines."
|
||||
assert result.tokens_used == 338
|
||||
|
||||
|
||||
async def test_summarize_text_truncates_summary_at_500_chars():
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.content = "x" * 1000
|
||||
mock_resp.usage_metadata = {"total_tokens": 100}
|
||||
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||
result = await summarize_text(content="x", ext=".md", name="x.md")
|
||||
assert len(result.summary) <= 500
|
||||
|
||||
|
||||
async def test_summarize_image_uses_vision_content_blocks():
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.content = "Final logo on white background."
|
||||
mock_resp.usage_metadata = {"total_tokens": 500}
|
||||
captured = {}
|
||||
|
||||
async def fake_llm_vision(messages):
|
||||
captured["messages"] = messages
|
||||
return mock_resp
|
||||
|
||||
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
|
||||
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
|
||||
|
||||
assert "Final logo" in result.summary
|
||||
assert result.tokens_used == 500
|
||||
# last message contains an image content block
|
||||
last = captured["messages"][-1]
|
||||
assert any(
|
||||
isinstance(p, dict) and p.get("type") == "image_url"
|
||||
for p in (last.content if isinstance(last.content, list) else [])
|
||||
)
|
||||
|
||||
|
||||
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
|
||||
# pypdf.PdfReader returns text from pages
|
||||
from app.core import folder_indexer
|
||||
class FakePage:
|
||||
def extract_text(self): return "PDF page content with project info."
|
||||
class FakeReader:
|
||||
pages = [FakePage(), FakePage()]
|
||||
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
|
||||
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
|
||||
async def fake_llm(messages): return mock_resp
|
||||
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
|
||||
assert "Project info" in result.summary
|
||||
assert result.tokens_used == 50
|
||||
|
||||
|
||||
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
|
||||
from app.core import folder_indexer
|
||||
class FakePara:
|
||||
def __init__(self, t): self.text = t
|
||||
class FakeDoc:
|
||||
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
|
||||
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
|
||||
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
|
||||
async def fake_llm(messages): return mock_resp
|
||||
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
|
||||
assert result.summary == "Heading and body."
|
||||
94
tests/test_folder_quota.py
Normal file
94
tests/test_folder_quota.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Folder quota helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.billing.quota import (
|
||||
check_folder_quota,
|
||||
add_token_usage,
|
||||
QuotaExceeded,
|
||||
)
|
||||
from app.models import MonthlyTokenUsage
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
|
||||
with pytest.raises(QuotaExceeded) as exc:
|
||||
await check_folder_quota(
|
||||
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
|
||||
)
|
||||
assert exc.value.reason == "max_files"
|
||||
|
||||
|
||||
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
|
||||
# No raise
|
||||
await check_folder_quota(
|
||||
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
|
||||
)
|
||||
|
||||
|
||||
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
db.add(MonthlyTokenUsage(
|
||||
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
|
||||
))
|
||||
await db.commit()
|
||||
with pytest.raises(QuotaExceeded) as exc:
|
||||
await check_folder_quota(
|
||||
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
|
||||
)
|
||||
assert exc.value.reason == "monthly_tokens"
|
||||
|
||||
|
||||
async def test_check_folder_quota_power_unlimited(db, test_user_power):
|
||||
await check_folder_quota(
|
||||
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
|
||||
)
|
||||
|
||||
|
||||
async def test_add_token_usage_atomic_increment(db, test_user_free):
|
||||
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
|
||||
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
row = (await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == test_user_free.id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == "folder_index",
|
||||
)
|
||||
)).scalar_one()
|
||||
assert row.tokens_used == 4000
|
||||
|
||||
|
||||
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
|
||||
result = await add_token_usage(
|
||||
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
|
||||
)
|
||||
assert result.exhausted is True
|
||||
assert result.tokens_used == 150_000
|
||||
|
||||
|
||||
def test_quota_check_endpoint_rejects(client, auth_headers_free):
|
||||
res = client.post(
|
||||
"/api/v1/billing/quota/check",
|
||||
json={"feature": "folder_index", "estimated_files": 500},
|
||||
headers=auth_headers_free,
|
||||
)
|
||||
assert res.status_code == 402
|
||||
body = res.json()
|
||||
assert body["detail"]["reason"] == "max_files"
|
||||
|
||||
|
||||
def test_quota_check_endpoint_passes(client, auth_headers_free):
|
||||
res = client.post(
|
||||
"/api/v1/billing/quota/check",
|
||||
json={"feature": "folder_index", "estimated_files": 50},
|
||||
headers=auth_headers_free,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"ok": True}
|
||||
69
tests/test_manifest_injection.py
Normal file
69
tests/test_manifest_injection.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def test_format_folder_manifest_basic():
|
||||
manifest = {
|
||||
"folderPath": "D:\\Acme",
|
||||
"lastScannedAt": "2h ago",
|
||||
"files": [
|
||||
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
|
||||
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
|
||||
],
|
||||
}
|
||||
out = format_folder_manifest(manifest)
|
||||
assert "<linked_folder>" in out
|
||||
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
|
||||
assert "[text]" in out
|
||||
assert "[image]" in out
|
||||
|
||||
|
||||
def test_format_folder_manifest_truncates_past_budget():
|
||||
files = [
|
||||
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
|
||||
for i in range(2000)
|
||||
]
|
||||
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
|
||||
assert "more files omitted" in out
|
||||
# Rough token check
|
||||
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
|
||||
|
||||
|
||||
def test_format_folder_manifest_null_returns_empty():
|
||||
assert format_folder_manifest(None) == ""
|
||||
assert format_folder_manifest({"files": []}) == ""
|
||||
|
||||
|
||||
async def test_brief_multi_project_manifest_top_5_per_project():
|
||||
fake_response = [
|
||||
{
|
||||
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
|
||||
"lastScannedAt": "now",
|
||||
"files": [
|
||||
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
|
||||
for i in range(10)
|
||||
],
|
||||
},
|
||||
{
|
||||
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
|
||||
"lastScannedAt": "now",
|
||||
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
|
||||
},
|
||||
]
|
||||
with patch(
|
||||
"app.core.deep_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"projects": fake_response}),
|
||||
):
|
||||
from app.core.deep_agent import build_brief_multi_project_manifest
|
||||
out = await build_brief_multi_project_manifest()
|
||||
# Project 1 has 10 files, only top 5 by mtimeMs should appear
|
||||
assert out.count("[p1]") <= 5
|
||||
# Project 2 has 1 file, must appear
|
||||
assert "[p2]" in out or "Beta" in out
|
||||
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Tests for WS folder index_session handlers (Task 9).
|
||||
|
||||
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||
no real WS connection or LLM call is made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.api.routes.device_ws import (
|
||||
_handle_index_session_start,
|
||||
_handle_index_file_batch,
|
||||
_handle_index_session_cancel,
|
||||
_index_sessions,
|
||||
)
|
||||
from app.billing.quota import add_token_usage
|
||||
from app.core.folder_indexer import IndexResult
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import WsFrameType
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
USER_ID = TEST_USER_IDS["free"]
|
||||
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||
|
||||
|
||||
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||
|
||||
class _FakeWebSocket:
|
||||
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent: list[dict] = []
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
self.sent.append(json.loads(text))
|
||||
|
||||
def sent_types(self) -> list[str]:
|
||||
return [f["type"] for f in self.sent]
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _make_session_id() -> str:
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||
async def _fake(**kwargs) -> IndexResult:
|
||||
return IndexResult(summary=summary, tokens_used=tokens)
|
||||
return _fake
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def _clean_sessions():
|
||||
"""Ensure _index_sessions is empty before and after each test."""
|
||||
_index_sessions.clear()
|
||||
yield
|
||||
_index_sessions.clear()
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def test_index_session_happy_path(db_session):
|
||||
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Register session.
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"projectId": "proj-1",
|
||||
"totalFiles": 2,
|
||||
})
|
||||
|
||||
# Verify session was registered.
|
||||
assert session_id in _index_sessions
|
||||
assert _index_sessions[session_id]["total"] == 2
|
||||
assert _index_sessions[session_id]["processed"] == 0
|
||||
# No response frames expected for session_start.
|
||||
assert ws.sent == []
|
||||
|
||||
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||
with patch(
|
||||
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||
# We patch the module-level function in folder_indexer instead:
|
||||
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
# Wire db_session into the context manager.
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||
assert types.count(WsFrameType.index_file_result) == 2
|
||||
assert types.count(WsFrameType.index_session_progress) == 1
|
||||
assert types.count(WsFrameType.index_session_done) == 1
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "completed"
|
||||
|
||||
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||
assert progress_frame["processed"] == 2
|
||||
assert progress_frame["total"] == 2
|
||||
|
||||
# Verify session cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_cancel(db_session):
|
||||
"""start then cancel → index_session_done(cancelled)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 5,
|
||||
})
|
||||
assert session_id in _index_sessions
|
||||
|
||||
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||
|
||||
types = ws.sent_types()
|
||||
assert WsFrameType.index_session_done in types
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "cancelled"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_quota_exceeded(db_session):
|
||||
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
db_session.add(MonthlyTokenUsage(
|
||||
user_id=USER_ID,
|
||||
year_month=ym,
|
||||
feature="folder_index",
|
||||
tokens_used=100_000, # free tier cap exactly
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 1,
|
||||
})
|
||||
|
||||
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Should have 1 file result (success) then done(quota_exceeded).
|
||||
assert WsFrameType.index_file_result in types
|
||||
assert WsFrameType.index_session_done in types
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "quota_exceeded"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
Reference in New Issue
Block a user