27 Commits

Author SHA1 Message Date
Roberto
cc0e258e8c fix(api): WS index frames accept both camelCase and snake_case keys (Electron toSnakeCase compat) 2026-05-13 08:58:46 +02:00
Roberto
12e203e63d fix(api): multi-project manifest lists projects even with zero indexed files 2026-05-12 18:10:57 +02:00
Roberto
ffcd7390f0 feat(api): pagination + search + PDF/DOCX extract in folder agent tools 2026-05-12 17:31:43 +02:00
Roberto
91e880f9d4 fix(api): home agent falls back to multi-project folder manifest when no project_id 2026-05-12 16:54:47 +02:00
Roberto
7d47ca54be feat(api): emit Langfuse generation traces for folder indexer 2026-05-12 16:40:20 +02:00
Roberto
956fa88853 feat(api): multi-project folder manifest for daily brief
Add build_brief_multi_project_manifest() to deep_agent.py that fetches
all project folder manifests via execute_on_client and keeps the top 5
most-recently-modified files per project. Wire into run_home_brief in
brief_agent.py, injecting the <linked_folders> block into the system
prompt alongside FOLDER_TOOLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:40:47 +02:00
Roberto
fb2f59ccea feat(api): inject folder manifest into home agent when project context active
Add optional project_id param to run_home_stream. When set, fetch the linked
folder manifest via _fetch_project_manifest and prepend the <linked_folder>
block to the system prompt. Also build an explicit tools list that extends
_all_tools_for_user with FOLDER_TOOLS so the home agent can read folder
files. device_ws._handle_home_request extracts project_id / projectId from
the home_request frame and forwards it to the runner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:32:20 +02:00
Roberto
56dbb7f4cd feat(api): inject folder manifest into task brief agent
Add _fetch_project_manifest helper that calls read_project_folder_manifest
via execute_on_client. Wire it into run_task_brief_research_stream (new
optional project_id param) so the <linked_folder> block is prepended to the
system prompt when the task belongs to a linked project. Also bind
FOLDER_TOOLS into the task-brief tool palette so the agent can read folder
files. device_ws extracts project_id / projectId from the task_brief_request
frame and forwards it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:31:21 +02:00
Roberto
506f517851 feat(api): manifest formatter with token-budget truncation 2026-05-12 11:28:13 +02:00
Roberto
520c186991 feat(api): scoped read_project_folder_file tool with traversal guard
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:26:02 +02:00
Roberto
582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch,
index_file_result/progress/done), wire dispatch in device_ws message loop,
and implement _handle_index_session_start/cancel/file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00
Roberto
2aeb453229 feat(api): PDF + DOCX extraction in folder indexer
Add pypdf/python-docx deps, _extract_pdf_text/_extract_docx_text helpers,
and summarize_pdf/summarize_docx wrappers that delegate to summarize_text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:15:17 +02:00
Roberto
b7a4edac90 feat(api): folder_indexer.summarize_image via gpt-4o-mini vision
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:09:37 +02:00
Roberto
822b4cd8b1 feat(api): folder_indexer.summarize_text via gpt-4o-mini 2026-05-12 11:05:43 +02:00
Roberto
ab24fc4c91 feat(api): POST /billing/quota/check endpoint
Pre-flight quota check for folder_index. Returns 402 with reason
when file cap or monthly token budget would be exceeded; 200 {"ok": true}
otherwise. Also adds auth_headers_free fixture to conftest.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:14:56 +02:00
Roberto
a98e99f7a2 feat(api): folder quota helpers with atomic token usage
Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 08:23:22 +02:00
Roberto
a0ff285bcd feat(api): tier features for folder integration
Add folder_max_files and folder_monthly_tokens to all four tier dicts
in FEATURES, and add get_feature_value() helper to TierManager.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:39:36 +02:00
Roberto
177c1a87dd feat(api): MonthlyTokenUsage model + AgentRunLog.tokens_used
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:30:33 +02:00
Roberto
441a4ea05c chore(api): fix stale Revises comment in folder migration 2026-05-12 07:21:13 +02:00
Roberto
a693a64bf5 feat(api): add migration for folder token tracking
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:16:23 +02:00
Roberto
67562b8092 Add task brief research agent: Stage 1 deep-research + canvas draft emission
- run_task_brief_research() runner with brief-specific tool set and max_steps=12
- New agents: client_agent (list_clients, get_client) and relations_agent (query_relations)
- search_associative tool wrapping MemoryMiddleware semantic search
- BRIEF_RESEARCH_TOOLS constant: read-only task/project/note/timeline + memory + client/relations
- canvas block extraction in output_formatter (splits visible text from <canvas> draft)
- device_ws.py: task_brief_research request type; emits canvas_draft mutation on stream_end
- Stage 2 briefMode: briefing_context injected into floating system prompt when present
- briefingContext kwarg wired through compile_prompt call chain

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-04 15:09:58 +02:00
Roberto
6f4c68b359 Update note management from db vector to index 2026-04-30 00:11:17 +02:00
Roberto
c20c6d7853 Fix home message tools calls 2026-04-29 09:21:41 +02:00
Roberto
6787e690ba fix tools calls 2026-04-27 09:15:08 +02:00
Roberto
cb8f56d909 date format fix 2026-04-26 21:06:38 +02:00
Roberto
2c7cac9e03 Fix using tools in home agent 2026-04-19 14:48:05 +02:00
Roberto
ea9094f47f Add llm providers 2026-04-19 00:32:12 +02:00
34 changed files with 3102 additions and 165 deletions

@@ -21,6 +21,8 @@ OPENAI_API_KEY=
 ANTHROPIC_API_KEY=
 GOOGLE_API_KEY=
 CEREBRAS_API_KEY=
+GROQ_API_KEY=
+DEEPSEEK_API_KEY=

 # Default model used by any agent that does not have a specific override below.
 LLM_MODEL=gpt-5-mini
@@ -54,6 +56,10 @@ LLM_MODEL_CLOUD_PROCESSOR=
 # A small model (e.g. gpt-4o-mini) is sufficient.
 # LLM_MODEL_BRIEF_AGENT=

+# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
+# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
+# LLM_MODEL_TASK_BRIEF_AGENT=
+
 # Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
 LLM_MODEL_SETUP_AGENT=

@@ -0,0 +1,5 @@
## DEV
Run in DEV with command:
```
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
```

@@ -0,0 +1,46 @@
"""Add token tracking columns for folder integration.

Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID

# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        "agent_run_logs",
        sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
    )
    op.create_table(
        "monthly_token_usage",
        sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
        sa.Column("year_month", sa.String(7), nullable=False),
        sa.Column("feature", sa.String(64), nullable=False),
        sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
        sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
    )
    op.create_index(
        "ix_monthly_token_usage_user_month",
        "monthly_token_usage",
        ["user_id", "year_month"],
    )


def downgrade() -> None:
    op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
    op.drop_table("monthly_token_usage")
    op.drop_column("agent_run_logs", "tokens_used")
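
The composite primary key `(user_id, year_month, feature)` is what lets a per-month counter be incremented with a single upsert. The snippet below is only a sketch of that idea using the standard-library `sqlite3` module and `ON CONFLICT ... DO UPDATE` (available in SQLite 3.24 and later). It is not the project's `add_token_usage`, which the commit log describes as using `pg_insert` on PostgreSQL and read-then-write on SQLite. The helper names `add_tokens` and `make_demo_db`, and the `"2026-05"` format for `year_month`, are assumptions made for illustration.

```python
import sqlite3


def make_demo_db() -> sqlite3.Connection:
    """Create an in-memory table with the same shape as monthly_token_usage."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE monthly_token_usage (
            user_id     TEXT    NOT NULL,
            year_month  TEXT    NOT NULL,
            feature     TEXT    NOT NULL,
            tokens_used INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (user_id, year_month, feature)
        )
        """
    )
    return conn


def add_tokens(conn: sqlite3.Connection, user_id: str, year_month: str,
               feature: str, tokens: int) -> int:
    """Hypothetical helper: add `tokens` to the counter and return the new total."""
    conn.execute(
        """
        INSERT INTO monthly_token_usage (user_id, year_month, feature, tokens_used)
        VALUES (?, ?, ?, ?)
        ON CONFLICT (user_id, year_month, feature)
        DO UPDATE SET tokens_used = tokens_used + excluded.tokens_used
        """,
        (user_id, year_month, feature, tokens),
    )
    row = conn.execute(
        "SELECT tokens_used FROM monthly_token_usage "
        "WHERE user_id = ? AND year_month = ? AND feature = ?",
        (user_id, year_month, feature),
    ).fetchone()
    return row[0]
```

Because the increment happens inside one statement, two calls for the same key accumulate rather than overwrite each other, while a different `year_month` starts a fresh counter.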

@@ -0,0 +1,52 @@
"""Client agent — read-only tools for the clients table."""
from __future__ import annotations

import json
from typing import Any

from langchain_core.tools import tool

from app.core.ws_context import execute_on_client


@tool
async def list_clients(search: str = "", limit: int = 20) -> str:
    """List clients, optionally filtered by a name/email substring search.

    search: optional substring to match against client name or email.
    limit: max rows to return (default 20).
    """
    filters: dict[str, Any] = {"limit": limit}
    if search:
        filters["search"] = search
    result = await execute_on_client(action="select", table="clients", filters=filters)
    rows = result.get("rows", [])
    if not rows:
        return "No clients found."
    lines = [
        f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
        f"company: {r.get('company', '')})"
        for r in rows
    ]
    return f"Found {len(rows)} client(s):\n" + "\n".join(lines)


@tool
async def get_client(id: str) -> str:
    """Get full details for one client by UUID.

    id: the client's UUID.
    """
    if not id:
        return "Client id is required."
    result = await execute_on_client(action="get", table="clients", data={"id": id})
    row = None
    if result:
        # Prefer "row"; otherwise fall back to the first of "rows" (None when the list is empty).
        row = result.get("row") or next(iter(result.get("rows") or []), None)
    if not row:
        return f"Client '{id}' not found."
    return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"


CLIENT_TOOLS: list[Any] = [list_clients, get_client]

168
app/agents/folder_agent.py Normal file
@@ -0,0 +1,168 @@
"""Scoped file-read and search tools for the project folder feature."""
from __future__ import annotations

from langchain_core.tools import tool

from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
from app.core.ws_context import execute_on_client

# Cap returned slice size to keep tool output under control.
_MAX_RETURN_CHARS = 50_000
_MAX_SEARCH_MATCHES = 20


def _is_unsafe_path(rel: str) -> bool:
    if not rel:
        return True
    norm = rel.replace("\\", "/")
    if norm.startswith("/"):
        return True
    # Windows drive letter
    if len(rel) >= 2 and rel[1] == ":":
        return True
    parts = norm.split("/")
    return ".." in parts


async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
    """Return the raw Electron tool_result dict for a file read."""
    return await execute_on_client(
        action="read_project_folder_file",
        data={
            "projectId": project_id,
            "relativePath": relative_path,
            "offset": offset,
            "length": length,
        },
    )


def _decode(result: dict) -> tuple[str, str, int]:
    """Decode a tool_result into (text, kind, total_size). For pdf/docx,
    extracts text from base64. For images, returns a placeholder string.
    For text, content is already a sliced utf-8 string.
    """
    kind = result.get("kind", "text")
    content = result.get("content", "") or ""
    total = int(result.get("totalSize", 0) or 0)
    if kind == "image":
        return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
    if kind == "pdf":
        return (_extract_pdf_text(content), kind, total)
    if kind == "docx":
        return (_extract_docx_text(content), kind, total)
    return (content, kind, total)


@tool
async def read_project_folder_file(
    project_id: str,
    relative_path: str,
    offset: int = 0,
    length: int = _MAX_RETURN_CHARS,
) -> str:
    """Read a slice of a file inside the project's linked folder.

    Args:
        project_id: project ID.
        relative_path: path relative to the linked folder root.
        offset: char offset to start reading from (0 = beginning).
        length: max chars to return. Default 50000. Use smaller values to save tokens.

    Returns text content slice with a header showing position. Header tells you
    when more content is available; call again with the suggested next offset.

    For PDF / DOCX files the backend extracts text first, then applies offset/length
    on the extracted text. For images returns a placeholder; navigate with the
    manifest summary instead.
    """
    if _is_unsafe_path(relative_path):
        return "Access denied"

    result = await _fetch_file(project_id, relative_path, offset, length)
    text, kind, total_size = _decode(result)
    if not text and kind in ("missing", "error"):
        return f"File not found or unreadable: {relative_path}"

    if kind in ("pdf", "docx"):
        # Backend extracted full text — apply offset/length on chars.
        sliced = text[offset:offset + length]
        slice_end = min(offset + length, len(text))
        header = (
            f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
            f"totalChars={len(text)}]"
        )
        if slice_end < len(text):
            header += f"\n[More content available — call again with offset={slice_end}.]"
        return header + "\n" + sliced

    if kind == "text":
        slice_end = offset + len(text)
        header = (
            f"[file={relative_path} kind=text offset={offset} end={slice_end} "
            f"totalBytes={total_size}]"
        )
        if slice_end < total_size:
            header += f"\n[More content available — call again with offset={slice_end}.]"
        return header + "\n" + text

    # image or unknown
    return text


@tool
async def search_project_folder_file(
    project_id: str,
    relative_path: str,
    query: str,
    context_lines: int = 3,
) -> str:
    """Search a project folder file for a query string (case-insensitive substring).

    Args:
        project_id: project ID.
        relative_path: path relative to the linked folder root.
        query: text to search for.
        context_lines: number of lines of context around each match (default 3).

    Returns matching line ranges with surrounding context and 1-based line numbers.
    Capped at 20 matches; if more exist the header shows the total.

    Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
    Images and binary files are not searchable.
    """
    if _is_unsafe_path(relative_path):
        return "Access denied"
    if not query:
        return "Empty query."

    # For text we still need full file; pass length=very large.
    result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
    text, kind, _ = _decode(result)
    if not text and kind in ("missing", "error"):
        return f"File not found or unreadable: {relative_path}"
    if kind == "image":
        return "Cannot search inside images."

    lines = text.splitlines()
    q = query.lower()
    matches = [i for i, line in enumerate(lines) if q in line.lower()]
    if not matches:
        return f"No matches for '{query}' in {relative_path}."

    shown = matches[:_MAX_SEARCH_MATCHES]
    snippets: list[str] = []
    for i in shown:
        start = max(0, i - context_lines)
        end = min(len(lines), i + context_lines + 1)
        block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
        snippets.append(block)

    header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
    body = "\n---\n".join(snippets)
    return header + "\n" + body


FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
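
To illustrate what the traversal guard accepts and rejects, the sketch below copies `_is_unsafe_path` from this file into a standalone snippet and runs it over a few sample paths. The sample paths and the `results` dictionary are made up for illustration. The check is purely lexical: it inspects the string and never touches the filesystem, so it does not by itself say anything about symlinks inside the linked folder.

```python
def _is_unsafe_path(rel: str) -> bool:
    """Copy of the guard from folder_agent.py, reproduced so the demo is self-contained."""
    if not rel:
        return True
    norm = rel.replace("\\", "/")
    if norm.startswith("/"):
        return True
    # Windows drive letter
    if len(rel) >= 2 and rel[1] == ":":
        return True
    parts = norm.split("/")
    return ".." in parts


# Hypothetical inputs covering each branch of the guard.
samples = [
    "docs/spec.md",             # plain relative path
    "notes..txt",               # dots inside a file name are not a traversal
    "",                         # empty path
    "/etc/passwd",              # absolute POSIX path
    "C:\\Windows\\system.ini",  # Windows drive letter
    "docs/../../secret.txt",    # parent-directory traversal
    "..\\secret.txt",           # traversal written with backslashes
]

results = {path: _is_unsafe_path(path) for path in samples}

for path, unsafe in results.items():
    print(f"{path!r:32} -> {'rejected' if unsafe else 'allowed'}")
```

Only whole path segments equal to `..` trigger the rejection, which is why a file named `notes..txt` is still allowed.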

@@ -1,13 +1,14 @@
-"""Note agent — Markdown note management (list, get, create, update, delete)."""
+"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
 from __future__ import annotations

+import asyncio
 import re
 from typing import Any

 from langchain_core.tools import tool

-from app.core.llm import embed
+from app.core.note_summarizer import generate_note_summary
 from app.core.ws_context import execute_on_client

 _UUID_RE = re.compile(
@@ -19,9 +20,21 @@ def _is_uuid(value: str) -> bool:
     return bool(_UUID_RE.match(value))


+def _fmt_summary(row: dict) -> str:
+    summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
+    if summary:
+        return f"{summary}"
+    snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
+    return f"{snippet}" if snippet else ""
+
+
 @tool
 async def list_notes(project_id: str = "") -> str:
-    """List notes, optionally scoped to a project by project_id."""
+    """List notes with AI summaries, optionally scoped to a project by project_id.
+
+    Returns id, title, and ai_summary for each note so you can decide which
+    note to read in full with get_note before creating or updating.
+    """
     normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
     result = await execute_on_client(
         action="select",
@@ -31,7 +44,7 @@ async def list_notes(project_id: str = "") -> str:
     rows = result.get("rows", [])
     if not rows:
         return "No notes found."
-    lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
+    lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
     return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@@ -66,14 +79,10 @@ async def create_note(
         },
     )
     row = result["row"]
-    # Index the note content in the vector store.
-    vector = await embed(content)
-    await execute_on_client(
-        action="vector_upsert",
-        data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
-        vector=vector,
-    )
-    return f"Note created: '{row['title']}' (id: {row['id']})."
+    note_id: str = row["id"]
+    # Generate summary asynchronously — fire-and-forget.
+    asyncio.create_task(_refresh_summary(note_id, title, content))
+    return f"Note created: '{row['title']}' (id: {note_id})."


 @tool
@@ -82,7 +91,8 @@ async def update_note(
     title: str = "",
     content: str = "",
 ) -> str:
-    """Update an existing note. Only pass fields that should change.
+    """Update an existing note directly (no approval required).
+    Use propose_note_edit instead when human review is needed.
     note_id: UUID of the note (required)
     If you need to preserve existing content, call get_note first.
     """
@@ -97,17 +107,63 @@ async def update_note(
         data={"id": note_id, "updates": updates},
     )
     row = result["row"]
-    # Re-index if content changed.
     if content:
-        vector = await embed(content)
-        await execute_on_client(
-            action="vector_upsert",
-            data={"id": note_id, "projectId": row.get("projectId"), "content": content},
-            vector=vector,
-        )
+        new_title = title or row.get("title", "")
+        asyncio.create_task(_refresh_summary(note_id, new_title, content))
     return f"Note updated: '{row['title']}' (id: {row['id']})."


+@tool
+async def propose_note_edit(
+    note_id: str,
+    edit_type: str,
+    proposed_content: str,
+    reasoning: str = "",
+    anchor_before: str = "",
+    anchor_text: str = "",
+    agent_id: str = "",
+    run_id: str = "",
+) -> str:
+    """Propose an AI edit to an existing note, pending human approval.
+    Use this instead of update_note when review_required is true.
+    The user will see the proposal highlighted before it is merged.
+
+    note_id: UUID of the target note (required)
+    edit_type: 'append' | 'insert' | 'replace'
+      - append: adds proposed_content at the end of the note
+      - insert: inserts proposed_content immediately after anchor_before text
+      - replace: replaces the first occurrence of anchor_text with proposed_content
+    proposed_content: the new Markdown text to add or substitute (required)
+    reasoning: brief explanation shown to the user (recommended)
+    anchor_before: for 'insert' — the text snippet that precedes the insertion point
+    anchor_text: for 'replace' — the exact text to be replaced
+    agent_id: agent identifier (for traceability)
+    run_id: run identifier (for traceability)
+    """
+    if edit_type not in ("append", "insert", "replace"):
+        return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
+
+    result = await execute_on_client(
+        action="propose_note_edit",
+        data={
+            "noteId": note_id,
+            "type": edit_type,
+            "proposedContent": proposed_content,
+            "reasoning": reasoning or None,
+            "anchorBefore": anchor_before or None,
+            "anchorText": anchor_text or None,
+            "agentId": agent_id or None,
+            "runId": run_id or None,
+        },
+    )
+    edit_id = result.get("id", "?")
+    return (
+        f"Edit proposal created (id: {edit_id}) for note {note_id}. "
+        f"Status: pending user approval."
+    )
+
+
 @tool
 async def delete_note(note_id: str) -> str:
     """Delete a note permanently by its UUID."""
@@ -115,11 +171,32 @@ async def delete_note(note_id: str) -> str:
     return f"Note {note_id} deleted."


+async def _refresh_summary(note_id: str, title: str, content: str) -> None:
+    """Generate and persist the AI summary for a note. Fire-and-forget."""
+    try:
+        summary = await generate_note_summary(title, content)
+        if summary:
+            await execute_on_client(
+                action="update",
+                table="notes",
+                data={
+                    "id": note_id,
+                    "updates": {
+                        "aiSummary": summary,
+                        "aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
+                    },
+                },
+            )
+    except Exception:
+        pass  # fire-and-forget; errors logged by generate_note_summary
+
+
 NOTE_TOOLS: list[Any] = [
     list_notes,
     get_note,
     create_note,
     update_note,
+    propose_note_edit,
     delete_note,
 ]

@@ -0,0 +1,63 @@
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
from __future__ import annotations

from typing import Any

from langchain_core.tools import tool

from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session

# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
# Each tool closure captures the user_id bound at factory time.


def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
    """Return a query_relations tool bound to *user_id*."""

    @tool
    async def query_relations(
        subject_label: str = "",
        predicate: str = "",
        object_label: str = "",
        limit: int = 10,
    ) -> str:
        """Query the relational memory graph for entity relationships.

        Returns rows where subject ↔ predicate ↔ object match the given filters.
        All parameters are optional — omit to retrieve all relations up to limit.

        subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
        predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
        object_label: entity label on the right side (e.g. a project name, "Website Redesign").
        limit: max rows to return (default 10).
        """
        import logging

        logger = logging.getLogger(__name__)
        logger.info(
            "relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
            trace_id or "-", user_id, subject_label, predicate, object_label,
        )
        async with async_session() as db:
            memory = MemoryMiddleware(db)
            rows = await memory.query_relations(
                user_id=user_id,
                subject=subject_label or None,
                predicate=predicate or None,
                object_=object_label or None,
                limit=limit,
            )
        if not rows:
            return "No relational memory entries found for the given filters."
        lines = [
            f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
            + (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
            for r in rows
        ]
        return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)

    return query_relations
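
The factory above binds `user_id` through a closure instead of exposing it as a tool parameter. Stripped of LangChain and the database, the idea can be sketched as follows. The names `make_scoped_lookup`, `alice_lookup`, and `bob_lookup` are hypothetical and exist only for this illustration.

```python
import asyncio
from typing import Awaitable, Callable


def make_scoped_lookup(user_id: str) -> Callable[..., Awaitable[str]]:
    """Hypothetical factory: return a coroutine function bound to one user."""

    async def lookup(subject_label: str = "") -> str:
        # user_id is read from the enclosing scope. It is not a parameter,
        # so whoever chooses the call arguments cannot point it at another user.
        return f"user={user_id} subject={subject_label or '*'}"

    return lookup


alice_lookup = make_scoped_lookup("alice")
bob_lookup = make_scoped_lookup("bob")
```

Each call to the factory yields an independent function, so two concurrent sessions never share a `user_id` through module-level state.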

@@ -26,32 +26,141 @@ def _is_uuid(value: str) -> bool:
 async def list_tasks(
     project_id: str = "",
     status: str = "",
+    priority: str = "",
+    assignee: str = "",
     search: str = "",
     order_by: str = "",
+    order_dir: str = "",
+    due_date_from: int = -1,
+    due_date_to: int = -1,
+    created_at_from: int = -1,
+    created_at_to: int = -1,
+    completed_at_from: int = -1,
+    completed_at_to: int = -1,
+    is_ai_suggested: int = -1,
+    limit: int = 50,
+    offset: int = 0,
 ) -> str:
-    """List tasks, optionally filtered by project_id, status (todo|in_progress|done),
-    a search string, or an order_by field name (dueDate|priority|createdAt)."""
+    """List tasks with optional filters. Returns up to `limit` results (default 50).
+
+    project_id: UUID of the project to scope results to.
+    status: filter by status — todo | in_progress | done.
+    priority: filter by priority — high | medium | low.
+    assignee: substring to match against assignee names. OMIT unless the user explicitly
+        names a person or refers to themselves ("my tasks", "assigned to me", "mine").
+        Do NOT default to the current user.
+    search: substring search across title and description.
+    order_by: sort field — dueDate | priority | createdAt | completedAt.
+    order_dir: asc (default) | desc.
+    due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
+    created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
+    completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
+    is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
+    limit: max rows to return (default 50). Use with offset to paginate.
+    offset: skip first N rows (default 0).
+
+    Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
+    Tip — prefer count_tasks for "how many" questions to avoid listing rows.
+    Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
+        take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
+        do not compute boundaries from the current UTC instant.
+    """
     normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
-    result = await execute_on_client(
-        action="select",
-        table="tasks",
-        filters={
+    filters: dict[str, Any] = {
         "projectId": normalized_project_id or None,
         "status": status or None,
+        "priority": priority or None,
         "search": search or None,
         "orderBy": order_by or None,
-        },
-    )
+        "orderDir": order_dir or None,
+        "limit": limit,
+        "offset": offset,
+    }
+    if assignee:
+        filters["assignee"] = assignee
+    if due_date_from != -1:
+        filters["dueDateFrom"] = due_date_from
+    if due_date_to != -1:
+        filters["dueDateTo"] = due_date_to
+    if created_at_from != -1:
+        filters["createdAtFrom"] = created_at_from
+    if created_at_to != -1:
+        filters["createdAtTo"] = created_at_to
+    if completed_at_from != -1:
+        filters["completedAtFrom"] = completed_at_from
+    if completed_at_to != -1:
+        filters["completedAtTo"] = completed_at_to
+    if is_ai_suggested != -1:
+        filters["isAiSuggested"] = is_ai_suggested
+    result = await execute_on_client(action="select", table="tasks", filters=filters)
     rows = result.get("rows", [])
     if not rows:
         return "No tasks found matching the given filters."
     lines = [
-        f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
+        f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
+        f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
+        f"projectId: {r.get('projectId')}, id: {r['id']})"
         for r in rows
     ]
     return f"Found {len(rows)} task(s):\n" + "\n".join(lines)


+@tool
+async def count_tasks(
+    project_id: str = "",
+    status: str = "",
+    priority: str = "",
+    assignee: str = "",
+    search: str = "",
+    due_date_from: int = -1,
+    due_date_to: int = -1,
+    created_at_from: int = -1,
+    created_at_to: int = -1,
+    completed_at_from: int = -1,
+    completed_at_to: int = -1,
+    is_ai_suggested: int = -1,
+) -> str:
+    """Count tasks matching the given filters without returning rows.
+
+    Use this instead of list_tasks for "how many" questions — it is much cheaper.
+    Same filter parameters as list_tasks (no limit/offset/order_by needed).
+
+    assignee: OMIT unless the user explicitly names a person or refers to themselves
+        ("my tasks"). Do NOT default to the current user.
+    due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
+    created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
+    completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
+
+    Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
+        do not compute boundaries from the current UTC instant.
+    """
+    normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
+    filters: dict[str, Any] = {
+        "projectId": normalized_project_id or None,
+        "status": status or None,
+        "priority": priority or None,
+        "search": search or None,
+    }
+    if assignee:
+        filters["assignee"] = assignee
+    if due_date_from != -1:
+        filters["dueDateFrom"] = due_date_from
+    if due_date_to != -1:
+        filters["dueDateTo"] = due_date_to
+    if created_at_from != -1:
+        filters["createdAtFrom"] = created_at_from
+    if created_at_to != -1:
+        filters["createdAtTo"] = created_at_to
+    if completed_at_from != -1:
+        filters["completedAtFrom"] = completed_at_from
+    if completed_at_to != -1:
+        filters["completedAtTo"] = completed_at_to
+    if is_ai_suggested != -1:
+        filters["isAiSuggested"] = is_ai_suggested
+    result = await execute_on_client(action="count", table="tasks", filters=filters)
+    return f"Task count: {result.get('count', 0)}"
+
+
 @tool
 async def create_task(
     title: str,
@@ -72,6 +181,8 @@ async def create_task(
     due_date: Unix timestamp in milliseconds; 0 means no due date
     project_id: optional UUID of the parent project
     is_ai_suggested: 1 if proactively suggested, 0 if user-requested
+
+    completedAt is set automatically when status is 'done'.
     """
     result = await execute_on_client(
         action="insert",
@@ -90,7 +201,7 @@ async def create_task(
     row = result["row"]
     return (
         f"Task created: '{row['title']}' "
-        f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
+        f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
     )
@@ -108,6 +219,10 @@ async def update_task(
     """Update fields on an existing task. Only pass fields you want to change.
     task_id: the task's UUID (required)
     due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
+
+    completedAt is managed automatically:
+      - setting status to 'done' records the current timestamp
+      - changing status away from 'done' clears completedAt
     """
     updates: dict[str, Any] = {}
     if title:
@@ -130,7 +245,7 @@ async def update_task(
         data={"id": task_id, "updates": updates},
     )
     row = result["row"]
-    return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
+    return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"


 @tool
@@ -141,21 +256,36 @@ async def delete_task(task_id: str) -> str:
 @tool
-async def list_tasks_due_today() -> str:
-    """List all tasks whose due date falls on today's date."""
-    now = datetime.now(tz=timezone.utc)
-    start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
-    end_ms = start_ms + 86_400_000 - 1  # last ms of today
+async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
+    """List all tasks whose due date falls on today's date.
+
+    user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
+    Always pass the user's timezone so 'today' is computed in their local time.
+    include_done: set True to also include already-completed tasks due today (default False).
+    """
+    try:
+        from zoneinfo import ZoneInfo
+        tz = ZoneInfo(user_timezone or "UTC")
+    except Exception:
+        tz = timezone.utc
+    now_local = datetime.now(tz=tz)
+    start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
+    start_ms = int(start_dt.timestamp() * 1000)
+    end_ms = start_ms + 86_400_000 - 1
+    filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
+    if not include_done:
+        filters["status"] = "todo"
     result = await execute_on_client(
         action="select",
         table="tasks",
-        filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
+        filters=filters,
     )
     rows = result.get("rows", [])
     if not rows:
         return "No tasks are due today."
     lines = [
-        f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
+        f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
+        f"projectId: {r.get('projectId')}, id: {r['id']})"
         for r in rows
     ]
     return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
@@ -193,7 +323,6 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
     )
     row = result.get("row", {})
     row_author = row.get("author", author)
-    # Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
     row_task_id = row.get("taskId") or row.get("task_id") or task_id
     row_comment_id = row.get("id", "unknown")
     return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@@ -211,6 +340,7 @@ async def delete_task_comment(comment_id: str) -> str:
 TASK_TOOLS: list[Any] = [
     list_tasks,
+    count_tasks,
     create_task,
     update_task,
     delete_task,
@@ -222,6 +352,7 @@ TASK_TOOLS: list[Any] = [
 TASK_READ_TOOLS: list[Any] = [
     list_tasks,
+    count_tasks,
     list_tasks_due_today,
     list_task_comments,
 ]
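
The `list_tasks_due_today` change in this file derives the end of the day as `start_ms + 86_400_000 - 1`. On a daylight-saving transition date the local day lasts 23 or 25 hours, so a fixed 24-hour offset can end the window early or run into the next day. A minimal sketch of one alternative, assuming a hypothetical helper named `local_day_bounds_ms` that is not part of this commit, takes the end from the next local midnight instead:

```python
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo


def local_day_bounds_ms(now: datetime, tz_name: str = "UTC") -> tuple[int, int]:
    """Return (start_ms, end_ms) of the local calendar day containing `now`.

    `now` is expected to be timezone-aware. Hypothetical helper, for illustration.
    """
    try:
        tz = ZoneInfo(tz_name or "UTC")
    except Exception:
        # Same fallback as the diff: unknown zone names degrade to UTC.
        tz = timezone.utc
    local = now.astimezone(tz)
    start = datetime(local.year, local.month, local.day, tzinfo=tz)
    next_date = local.date() + timedelta(days=1)
    # Next local midnight, resolved with that date's own UTC offset.
    next_midnight = datetime(next_date.year, next_date.month, next_date.day, tzinfo=tz)
    start_ms = int(start.timestamp() * 1000)
    end_ms = int(next_midnight.timestamp() * 1000) - 1
    return start_ms, end_ms
```

The same helper would apply to `list_timelines_today`, which uses the identical arithmetic.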


@@ -20,19 +20,128 @@ def _is_uuid(value: str) -> bool:
 @tool
-async def list_timelines(project_id: str = "") -> str:
-    """List timelines. Provide project_id to scope to a specific project."""
+async def list_timelines(
+    project_id: str = "",
+    type: str = "",
+    is_completed: int = -1,
+    is_ai_suggested: int = -1,
+    order_by: str = "",
+    order_dir: str = "",
+    date_from: int = -1,
+    date_to: int = -1,
+    created_at_from: int = -1,
+    created_at_to: int = -1,
+    completed_at_from: int = -1,
+    completed_at_to: int = -1,
+    limit: int = 50,
+    offset: int = 0,
+) -> str:
+    """List timeline events (milestones, checkpoints, activities) with optional filters.
+    project_id: UUID to scope results to a specific project.
+    type: filter by event type — milestone | checkpoint | activity.
+    is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
+    is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
+    order_by: sort field — date (default) | createdAt | completedAt.
+    order_dir: asc (default) | desc.
+    date_from / date_to: ms epoch range for the event date. Use -1 to omit.
+    created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
+    completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
+    limit: max rows to return (default 50). Use with offset to paginate.
+    offset: skip first N rows (default 0).
+    Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
+    Tip — prefer count_timelines for "how many" questions to avoid listing rows.
+    Tip — for natural-language windows ("today", "this week", "last month", etc.)
+    take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
+    do not compute boundaries from the current UTC instant.
+    """
     normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
-    result = await execute_on_client(
-        action="select",
-        table="timelines",
-        filters={"projectId": normalized_project_id or None},
-    )
+    filters: dict[str, Any] = {
+        "projectId": normalized_project_id or None,
+        "orderBy": order_by or None,
+        "orderDir": order_dir or None,
+        "limit": limit,
+        "offset": offset,
+    }
+    if type:
+        filters["type"] = type
+    if is_completed != -1:
+        filters["isCompleted"] = is_completed
+    if is_ai_suggested != -1:
+        filters["isAiSuggested"] = is_ai_suggested
+    if date_from != -1:
+        filters["dateFrom"] = date_from
+    if date_to != -1:
+        filters["dateTo"] = date_to
+    if created_at_from != -1:
+        filters["createdAtFrom"] = created_at_from
+    if created_at_to != -1:
+        filters["createdAtTo"] = created_at_to
+    if completed_at_from != -1:
+        filters["completedAtFrom"] = completed_at_from
+    if completed_at_to != -1:
+        filters["completedAtTo"] = completed_at_to
+    result = await execute_on_client(action="select", table="timelines", filters=filters)
     rows = result.get("rows", [])
     if not rows:
-        return "No timelines found."
-    lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
-    return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
+        return "No timeline events found."
+    lines = [
+        f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
+        f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
+        f"projectId: {r.get('projectId')}, id: {r['id']})"
+        for r in rows
+    ]
+    return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
+@tool
+async def count_timelines(
+    project_id: str = "",
+    type: str = "",
+    is_completed: int = -1,
+    is_ai_suggested: int = -1,
+    date_from: int = -1,
+    date_to: int = -1,
+    created_at_from: int = -1,
+    created_at_to: int = -1,
+    completed_at_from: int = -1,
+    completed_at_to: int = -1,
+) -> str:
+    """Count timeline events matching the given filters without returning rows.
+    Use this instead of list_timelines for "how many" questions — it is much cheaper.
+    Same filter parameters as list_timelines (no limit/offset/order_by needed).
+    date_from / date_to: ms epoch range for the event date. Use -1 to omit.
+    completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
+    Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
+    do not compute boundaries from the current UTC instant.
+    """
+    normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
+    filters: dict[str, Any] = {"projectId": normalized_project_id or None}
+    if type:
+        filters["type"] = type
+    if is_completed != -1:
+        filters["isCompleted"] = is_completed
+    if is_ai_suggested != -1:
+        filters["isAiSuggested"] = is_ai_suggested
+    if date_from != -1:
+        filters["dateFrom"] = date_from
+    if date_to != -1:
+        filters["dateTo"] = date_to
+    if created_at_from != -1:
+        filters["createdAtFrom"] = created_at_from
+    if created_at_to != -1:
+        filters["createdAtTo"] = created_at_to
+    if completed_at_from != -1:
+        filters["completedAtFrom"] = completed_at_from
+    if completed_at_to != -1:
+        filters["completedAtTo"] = completed_at_to
+    result = await execute_on_client(action="count", table="timelines", filters=filters)
+    return f"Timeline event count: {result.get('count', 0)}"
 @tool
@@ -40,13 +149,19 @@ async def create_timeline(
     project_id: str,
     title: str,
     date: int,
+    type: str = "milestone",
+    is_completed: int = 0,
     is_ai_suggested: int = 0,
 ) -> str:
-    """Create a project timeline (milestone).
+    """Create a project timeline event.
     project_id: REQUIRED UUID of the parent project
-    title: descriptive name for the milestone
-    date: Unix timestamp in milliseconds
+    title: descriptive name for the event
+    date: Unix timestamp in milliseconds for the event date
+    type: milestone (default) | checkpoint | activity
+    is_completed: 1 if already completed, 0 if not (default 0)
     is_ai_suggested: 1 if proactively suggested, 0 if user-requested
+    completedAt is set automatically when is_completed is 1.
     """
     result = await execute_on_client(
         action="insert",
@@ -55,11 +170,13 @@ async def create_timeline(
             "projectId": project_id,
             "title": title,
             "date": date,
+            "type": type,
+            "isCompleted": is_completed,
             "isAiSuggested": is_ai_suggested,
         },
     )
     row = result["row"]
-    return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
+    return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
 @tool
@@ -67,52 +184,79 @@ async def update_timeline(
     timeline_id: str,
     title: str = "",
     date: int = -1,
+    is_completed: int = -1,
 ) -> str:
-    """Update a timeline. Only pass fields that should change.
-    timeline_id: UUID of the timeline (required)
+    """Update a timeline event. Only pass fields that should change.
+    timeline_id: UUID of the event (required)
     date: -1 means unchanged; any other value sets the new date (ms timestamp)
+    is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
+    completedAt is managed automatically:
+    - setting is_completed to 1 records the current timestamp
+    - setting is_completed to 0 clears completedAt
     """
     updates: dict[str, Any] = {}
     if title:
         updates["title"] = title
     if date != -1:
         updates["date"] = date
+    if is_completed != -1:
+        updates["isCompleted"] = is_completed
     result = await execute_on_client(
         action="update",
         table="timelines",
         data={"id": timeline_id, "updates": updates},
     )
     row = result["row"]
-    return f"Timeline updated: '{row['title']}' (id: {row['id']})"
+    return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
 @tool
 async def delete_timeline(timeline_id: str) -> str:
-    """Delete a timeline permanently by its UUID."""
+    """Delete a timeline event permanently by its UUID."""
     await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
-    return f"Timeline {timeline_id} deleted."
+    return f"Timeline event {timeline_id} deleted."
 @tool
-async def list_timelines_today() -> str:
-    """List all timeline events (milestones) whose date falls on today (UTC)."""
-    now = datetime.now(tz=timezone.utc)
-    start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
+async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
+    """List all timeline events whose date falls on today.
+
+    user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
+    Always pass the user's timezone so 'today' is computed in their local time.
+    include_completed: set False to exclude already-completed events (default True).
+    """
+    try:
+        from zoneinfo import ZoneInfo
+        tz = ZoneInfo(user_timezone or "UTC")
+    except Exception:
+        tz = timezone.utc
+    now_local = datetime.now(tz=tz)
+    start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
+    start_ms = int(start_dt.timestamp() * 1000)
     end_ms = start_ms + 86_400_000 - 1
+    filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
+    if not include_completed:
+        filters["isCompleted"] = 0
     result = await execute_on_client(
         action="select",
         table="timelines",
-        filters={"dateFrom": start_ms, "dateTo": end_ms},
+        filters=filters,
     )
     rows = result.get("rows", [])
     if not rows:
         return "No timeline events today."
-    lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
+    lines = [
+        f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
+        f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
+        for r in rows
+    ]
     return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
 TIMELINE_TOOLS: list[Any] = [
     list_timelines,
+    count_timelines,
     list_timelines_today,
     create_timeline,
     update_timeline,
@@ -121,5 +265,6 @@ TIMELINE_TOOLS: list[Any] = [
 TIMELINE_READ_TOOLS: list[Any] = [
     list_timelines,
+    count_timelines,
     list_timelines_today,
 ]
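
`list_timelines` and `count_timelines` repeat the same chain of `-1` sentinel checks to decide which filters reach the client. As a rough sketch of how that mapping could be shared, the block below assumes a hypothetical helper `build_timeline_filters` (not in this commit) and leaves out the `_is_uuid` normalization and the `orderBy` / `limit` / `offset` keys for brevity:

```python
from typing import Any

# snake_case tool argument -> camelCase filter key, as used in the diff above.
_INT_FILTER_KEYS: dict[str, str] = {
    "is_completed": "isCompleted",
    "is_ai_suggested": "isAiSuggested",
    "date_from": "dateFrom",
    "date_to": "dateTo",
    "created_at_from": "createdAtFrom",
    "created_at_to": "createdAtTo",
    "completed_at_from": "completedAtFrom",
    "completed_at_to": "completedAtTo",
}


def build_timeline_filters(
    project_id: str = "", event_type: str = "", **int_filters: int
) -> dict[str, Any]:
    """Hypothetical helper: keep only integer filters that differ from -1."""
    filters: dict[str, Any] = {"projectId": project_id or None}
    if event_type:
        filters["type"] = event_type
    for name, value in int_filters.items():
        key = _INT_FILTER_KEYS[name]  # raises KeyError for an unknown filter name
        if value != -1:
            filters[key] = value
    return filters
```

Because `-1` is the "omit" sentinel, a legitimate value of `0` (for example `is_completed=0`) still passes through, which matches the behaviour of the explicit `if ... != -1` checks.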


@@ -20,10 +20,13 @@ from fastapi import APIRouter, Depends, HTTPException, status
 from sqlalchemy import func, select
 from sqlalchemy.ext.asyncio import AsyncSession
+from pydantic import BaseModel
 from app.api.deps import get_current_user
 from app.billing.tier_manager import FEATURES
 from app.core.agent_runner import is_agent_running, run_local_agent
 from app.core.device_manager import device_manager
+from app.core.note_summarizer import generate_note_summary
 from app.db import get_session
 from app.models import AgentRunLog, LocalAgentConfig
 from app.schemas import (
@@ -230,3 +233,25 @@ async def trigger_agent_run(
     )
     return _to_run_log_response(run_log)
+# ── Note summary endpoint ──────────────────────────────────────────────────────
+class NoteSummarizeRequest(BaseModel):
+    title: str
+    content: str
+class NoteSummarizeResponse(BaseModel):
+    summary: str
+@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
+async def summarize_note(
+    body: NoteSummarizeRequest,
+    current_user: UserProfile = Depends(get_current_user),
+) -> NoteSummarizeResponse:
+    """Generate an AI summary for a note. Used by the Electron backfill on startup."""
+    summary = await generate_note_summary(body.title, body.content)
+    return NoteSummarizeResponse(summary=summary)


@@ -9,7 +9,7 @@ from __future__ import annotations
 from typing import Any
-from fastapi import APIRouter, Depends, Header, Request, status
+from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
 from pydantic import BaseModel
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -96,3 +96,37 @@ async def list_invoices(
     """
     invoices = await stripe_service.list_invoices(current_user.id, db)
     return invoices
+# ── Quota check ────────────────────────────────────────────────────────
+from app.billing.quota import check_folder_quota, QuotaExceeded  # noqa: E402
+class QuotaCheckRequest(BaseModel):
+    feature: str
+    estimated_files: int
+@router.post("/quota/check")
+async def quota_check(
+    payload: QuotaCheckRequest,
+    current_user: UserProfile = Depends(get_current_user),
+    db: AsyncSession = Depends(get_session),
+) -> dict:
+    """Pre-flight folder quota check. 402 if tier limits would be exceeded."""
+    if payload.feature != "folder_index":
+        raise HTTPException(status_code=400, detail="Unknown feature")
+    try:
+        await check_folder_quota(
+            user_id=current_user.id,
+            tier=current_user.tier,
+            estimated_files=payload.estimated_files,
+            db=db,
+        )
+    except QuotaExceeded as exc:
+        raise HTTPException(
+            status_code=402,
+            detail={"reason": exc.reason, "message": str(exc)},
+        )
+    return {"ok": True}
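
The pre-flight check above maps a `QuotaExceeded` raised by `check_folder_quota` (defined in `app/billing/quota.py` in this same change) to a 402 response with a `reason` and `message`. The decision rule itself can be sketched without the database or FastAPI. The names `check_limits` and `to_error_body` are hypothetical, and the numeric limits are passed in directly rather than looked up from a tier:

```python
UNLIMITED = -1  # sentinel the diff uses for "no limit" on a tier feature


class QuotaExceeded(Exception):
    """Local stand-in for the exception defined in app/billing/quota.py."""

    def __init__(self, reason: str, message: str) -> None:
        super().__init__(message)
        self.reason = reason  # "max_files" | "monthly_tokens"


def check_limits(estimated_files: int, max_files: int, used_tokens: int, monthly_cap: int) -> None:
    """Raise QuotaExceeded when either limit would be violated (hypothetical helper)."""
    if max_files != UNLIMITED and estimated_files > max_files:
        raise QuotaExceeded("max_files", f"Folder has {estimated_files} files; max is {max_files}.")
    if monthly_cap != UNLIMITED and used_tokens >= monthly_cap:
        raise QuotaExceeded("monthly_tokens", f"Monthly token budget exhausted ({used_tokens}/{monthly_cap}).")


def to_error_body(exc: QuotaExceeded) -> dict[str, str]:
    """Shape of the `detail` payload the route attaches to the 402 response."""
    return {"reason": exc.reason, "message": str(exc)}
```

Note the asymmetry carried over from the source: the file check uses `>` (a folder exactly at the limit passes), while the token check uses `>=` (a budget exactly at the cap is treated as exhausted).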


@@ -43,7 +43,8 @@ from app.api.routes.agent_setup import handle_journey_message, handle_journey_st
 from app.config.settings import settings
 from app.core.agent_runner import trigger_pending_runs
 from app.core.brief_agent import run_home_brief, run_project_brief
-from app.core.deep_agent import run_floating_stream, run_home_stream
+from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
+from app.core.output_formatter import extract_canvas_block
 from app.core.device_manager import device_manager
 from app.core.memory_middleware import MemoryMiddleware
 from app.core.output_formatter import StreamFormatter
@@ -56,6 +57,10 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/ws", tags=["device-ws"])
+# ── v7 folder index session state ─────────────────────────────────────
+# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
+_index_sessions: dict[str, dict] = {}
 _HEARTBEAT_INTERVAL = 30  # seconds
 _PONG_TIMEOUT = 10  # seconds — grace window after a ping
@@ -164,6 +169,11 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
                 _handle_brief_request(websocket, user_id, frame)
             )
+        elif frame_type == WsFrameType.task_brief_request:
+            asyncio.create_task(
+                _handle_task_brief_request(websocket, user_id, frame)
+            )
         elif frame_type == WsFrameType.journey_start:
             asyncio.create_task(
                 _handle_journey_start(websocket, user_id, frame)
@@ -174,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
                 _handle_journey_message(websocket, user_id, frame)
             )
+        elif frame_type == WsFrameType.index_session_start:
+            asyncio.create_task(
+                _handle_index_session_start(websocket, user_id, frame)
+            )
+        elif frame_type == WsFrameType.index_file_batch:
+            asyncio.create_task(
+                _handle_index_file_batch(websocket, user_id, frame)
+            )
+        elif frame_type == WsFrameType.index_session_cancel:
+            await _handle_index_session_cancel(websocket, frame)
         elif frame_type == "pong":
             # Heartbeat ack — nothing to do, connection is alive.
             pass
@@ -205,11 +228,13 @@ async def _handle_home_request(
     request_id = frame.get("request_id") or str(uuid4())
     message: str = frame.get("message", "")
     session_id: str = frame.get("session_id") or str(uuid4())
+    project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
     logger.info(
-        "device_ws: home_request_start user=%s req=%s session=%s msg=%s",
+        "device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
         user_id,
         request_id,
         session_id,
+        project_id,
         message[:200],
     )
@@ -226,6 +251,7 @@ async def _handle_home_request(
     context: dict = {
         "conversation_history": frame.get("conversation_history", []),
         "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
+        "format_prefs": frame.get("format_prefs"),
         **memory_context,
     }
@@ -233,7 +259,7 @@ async def _handle_home_request(
     set_client_executor(executor)
     response_chunks: list[str] = []
     try:
-        event_stream = run_home_stream(user_id, message, context)
+        event_stream = run_home_stream(user_id, message, context, project_id=project_id)
         formatter = StreamFormatter(request_id=request_id)
         async for ws_frame in formatter.format(event_stream):
             await websocket.send_text(ws_frame.model_dump_json())
@@ -293,8 +319,10 @@ async def _handle_floating_request(
     )
     context: dict = {
+        "conversation_history": frame.get("conversation_history", []),
         "scope": scope,
         "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
+        "format_prefs": frame.get("format_prefs"),
         **memory_context,
     }
@@ -380,6 +408,7 @@ async def _handle_brief_request(
     context: dict = {
         "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
+        "format_prefs": frame.get("format_prefs"),
         **memory_context,
     }
@@ -411,6 +440,98 @@ async def _handle_brief_request(
     )
# ── v6 Task Brief Handler ────────────────────────────────────────────
async def _handle_task_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
Streams the briefing markdown back to the client.
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
"""
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, project_id,
)
if not task_id:
await websocket.send_text(
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
)
return
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
f"task brief: {task_id}",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
await websocket.send_text(ws_frame.model_dump_json())
elif ws_frame.type == "stream_start":
await websocket.send_text(ws_frame.model_dump_json())
# stream_end is emitted below with mutations — skip formatter's version
except Exception as exc:
logger.error(
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
user_id, request_id, task_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
finally:
clear_client_executor()
# Extract canvas block then emit stream_end with optional mutations.
full_response = "".join(response_chunks)
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
mutations: list[dict] = []
if canvas_content:
mutations.append({
"type": "canvas_draft",
"content": canvas_content,
"kind": canvas_kind,
})
await websocket.send_text(
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
)
logger.info(
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
)
 # ── v4 Journey Handlers ─────────────────────────────────────────────
@@ -468,6 +589,174 @@ async def _handle_journey_message(
         clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Register a new folder index session. No response sent — client is declaring intent."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
project_id: str | None = frame.get("projectId") or frame.get("project_id")
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
if not session_id:
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
return
_index_sessions[session_id] = {
"user_id": user_id,
"project_id": project_id,
"processed": 0,
"total": total,
"cancelled": False,
}
logger.info(
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
user_id, session_id, project_id, total,
)
async def _handle_index_session_cancel(
websocket: WebSocket,
frame: dict,
) -> None:
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
session = _index_sessions.get(session_id)
if session:
session["cancelled"] = True
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "cancelled",
}))
_index_sessions.pop(session_id, None)
logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Process a batch of files for an index session, streaming results back."""
# Lazy imports to avoid heavy load at module startup.
from app.core.folder_indexer import ( # noqa: PLC0415
summarize_image,
summarize_pdf,
summarize_docx,
summarize_text,
)
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.billing.quota import add_token_usage # noqa: PLC0415
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
files: list[dict] = frame.get("files", [])
session = _index_sessions.get(session_id)
if not session or session.get("cancelled"):
return
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
cap: int | None = None if raw_cap == -1 else raw_cap
for file_info in files:
if session.get("cancelled"):
return
# Electron's toSnakeCase converts payload keys, so accept both forms.
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
kind: str = file_info.get("kind") or "text"
content: str = file_info.get("content") or ""
ext: str = file_info.get("ext") or ""
mime: str = file_info.get("mime") or "application/octet-stream"
name: str = rel_path.split("/")[-1] or rel_path
try:
if kind == "image":
res = await summarize_image(image_b64=content, mime=mime)
elif kind == "pdf":
res = await summarize_pdf(pdf_b64=content, name=name)
elif kind == "docx":
res = await summarize_docx(docx_b64=content, name=name)
else:
res = await summarize_text(content=content, ext=ext, name=name)
except Exception as exc:
logger.warning(
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
session_id, rel_path, exc,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": None,
"tokensUsed": 0,
"error": str(exc),
}))
session["processed"] += 1
continue
# Account for token usage and check cap.
usage = await add_token_usage(
user_id=user_id,
feature="folder_index",
tokens=res.tokens_used,
db=db,
cap=cap,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": res.summary,
"tokensUsed": res.tokens_used,
}))
session["processed"] += 1
if usage.exhausted:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "quota_exceeded",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session quota_exceeded user=%s session=%s",
user_id, session_id,
)
return
# After processing the batch, emit progress.
processed = session["processed"]
total = session["total"]
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_progress,
"sessionId": session_id,
"processed": processed,
"total": total,
}))
if processed >= total:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "completed",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session_done completed user=%s session=%s processed=%d",
user_id, session_id, processed,
)
 # ── Heartbeat ─────────────────────────────────────────────────────────
 async def _heartbeat_loop(websocket: WebSocket) -> None:
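
The folder index handlers in this file read every key in both spellings (`sessionId` / `session_id`, `relPath` / `rel_path`, and so on) because, per the inline comment, Electron's `toSnakeCase` may rewrite payload keys. A minimal sketch of that lookup as a single function, assuming a hypothetical helper name `get_either` that does not appear in the commit:

```python
from typing import Any


def get_either(frame: dict[str, Any], camel: str, snake: str, default: Any = None) -> Any:
    """Return the first truthy value of frame[camel] or frame[snake], else default.

    Mirrors the `a or b or default` chains in the handlers, so falsy values
    such as 0 or "" fall through to the next candidate.
    """
    return frame.get(camel) or frame.get(snake) or default


# Usage with a snake_case frame, as produced after key conversion.
frame = {"type": "index_session_start", "session_id": "s-1", "total_files": 12}
session_id = get_either(frame, "sessionId", "session_id", "")
total = int(get_either(frame, "totalFiles", "total_files", 0))
```

The truthiness-based fallback is convenient for strings, but it means a camelCase value of `0` is ignored in favour of the snake_case key; an explicit `in` check would be needed if `0` had to win.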

app/billing/quota.py Normal file

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession

from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier


class QuotaExceeded(Exception):
    """Raised when a folder operation cannot proceed under the user's tier."""

    def __init__(self, reason: str, message: str) -> None:
        super().__init__(message)
        self.reason = reason  # "max_files" | "monthly_tokens"


@dataclass
class TokenUsageResult:
    tokens_used: int
    exhausted: bool


def _current_year_month() -> str:
    return datetime.now(timezone.utc).strftime("%Y-%m")


_tier_manager = TierManager()


async def check_folder_quota(
    *,
    user_id: str,
    tier: BillingTier,
    estimated_files: int,
    db: AsyncSession,
) -> None:
    """Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
    would be violated. -1 in either feature means unlimited."""
    max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
    if max_files != -1 and estimated_files > max_files:
        raise QuotaExceeded(
            "max_files",
            f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
        )
    cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
    if cap == -1:
        return
    ym = _current_year_month()
    row = (
        await db.execute(
            select(MonthlyTokenUsage).where(
                MonthlyTokenUsage.user_id == user_id,
                MonthlyTokenUsage.year_month == ym,
                MonthlyTokenUsage.feature == "folder_index",
            )
        )
    ).scalar_one_or_none()
    used = row.tokens_used if row else 0
    if used >= cap:
        raise QuotaExceeded(
            "monthly_tokens",
            f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
        )


async def add_token_usage(
    *,
    user_id: str,
    feature: str,
    tokens: int,
    db: AsyncSession,
    cap: int | None = None,
) -> TokenUsageResult:
    """Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).

    Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
    back to a read-then-write on other engines (e.g. aiosqlite in tests).
    Returns post-update total and whether cap is exhausted.
    """
    ym = _current_year_month()
    # Detect dialect to choose between native upsert and portable fallback.
    dialect_name: str = db.bind.dialect.name if db.bind is not None else ""  # type: ignore[union-attr]
    if dialect_name == "postgresql":
        # Native atomic upsert — production path.
        stmt = (
            pg_insert(MonthlyTokenUsage)
            .values(
                user_id=user_id,
                year_month=ym,
                feature=feature,
                tokens_used=tokens,
            )
            .on_conflict_do_update(
                index_elements=["user_id", "year_month", "feature"],
                set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
            )
            .returning(MonthlyTokenUsage.tokens_used)
        )
        used: int = (await db.execute(stmt)).scalar_one()
        await db.commit()
    else:
        # Portable fallback — used in tests (SQLite) and any non-PG engine.
        row = (
            await db.execute(
                select(MonthlyTokenUsage).where(
                    MonthlyTokenUsage.user_id == user_id,
                    MonthlyTokenUsage.year_month == ym,
                    MonthlyTokenUsage.feature == feature,
                )
            )
        ).scalar_one_or_none()
        if row is None:
            row = MonthlyTokenUsage(
                user_id=user_id,
                year_month=ym,
                feature=feature,
                tokens_used=tokens,
            )
            db.add(row)
        else:
            row.tokens_used += tokens
        await db.commit()
        await db.refresh(row)
        used = row.tokens_used
    exhausted = cap is not None and cap != -1 and used >= cap
    return TokenUsageResult(tokens_used=used, exhausted=exhausted)
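
The cap rule at the end of `add_token_usage` can be tried out without a database. The following is a minimal sketch, assuming a hypothetical in-memory `InMemoryUsageStore` in place of the `MonthlyTokenUsage` table; it mirrors only the `exhausted` expression from the diff and is not the project's implementation.

```python
from __future__ import annotations

from dataclasses import dataclass

UNLIMITED = -1  # same sentinel the tier features use for "no limit"


@dataclass
class TokenUsageResult:
    tokens_used: int
    exhausted: bool


class InMemoryUsageStore:
    """Hypothetical stand-in for the MonthlyTokenUsage table (illustration only)."""

    def __init__(self) -> None:
        self._totals: dict[tuple[str, str, str], int] = {}

    def add(self, user_id: str, year_month: str, feature: str,
            tokens: int, cap: int | None = None) -> TokenUsageResult:
        key = (user_id, year_month, feature)
        used = self._totals.get(key, 0) + tokens
        self._totals[key] = used
        # Same rule as the diff: no cap given, or cap == -1, never exhausts.
        exhausted = cap is not None and cap != UNLIMITED and used >= cap
        return TokenUsageResult(tokens_used=used, exhausted=exhausted)


store = InMemoryUsageStore()
first = store.add("u1", "2026-05", "folder_index", 60_000, cap=100_000)
second = store.add("u1", "2026-05", "folder_index", 50_000, cap=100_000)
unlimited = store.add("u2", "2026-05", "folder_index", 10**9, cap=UNLIMITED)
```

Because the check runs after the increment, the stored total can end up above the cap by at most the size of the last batch; `check_folder_quota` is what blocks the next operation.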


@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
        "realtime_extraction": False, # batch queue (Phase 2)
        "relational_memory": False, # relational tier (Phase 3) — Pro+
        "proactive_mining": False, # Power+ only (Phase 5)
        "folder_max_files": 200,
        "folder_monthly_tokens": 100_000,
    },
    "pro": {
        "agents": -1, # unlimited
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
        "realtime_extraction": True, # fire-and-forget asyncio.create_task
        "relational_memory": True, # person/project predicates
        "proactive_mining": False, # Power+ only (Phase 5)
        "folder_max_files": 5000,
        "folder_monthly_tokens": 2_000_000,
    },
    "power": {
        "agents": -1,
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
        "realtime_extraction": True,
        "relational_memory": True, # all predicates incl. custom
        "proactive_mining": True, # scheduled pattern mining (Phase 5)
        "folder_max_files": -1, # unlimited
        "folder_monthly_tokens": -1, # unlimited
    },
    "team": {
        "agents": -1,
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
        "realtime_extraction": True,
        "relational_memory": True, # all predicates incl. custom
        "proactive_mining": True, # scheduled pattern mining (Phase 5)
        "folder_max_files": -1, # unlimited
        "folder_monthly_tokens": -1, # unlimited
    },
}
@@ -123,6 +131,13 @@ class TierManager:
            )
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)

    def get_feature_value(self, tier: BillingTier, feature: str) -> int:
        """Return integer feature value for tier. -1 means unlimited."""
        value = FEATURES.get(tier, FEATURES["free"]).get(feature)
        if not isinstance(value, int):
            return 0
        return value

    # ── Rate limiting ────────────────────────────────────────────────────
    def get_rate_limit(self, tier: BillingTier) -> int:
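
The lookup in `get_feature_value` has three outcomes worth seeing side by side: a known tier, an unknown tier falling back to `"free"`, and a missing feature returning `0`. The sketch below assumes a trimmed, hypothetical `FEATURES` table and plain string tiers. It also shows one Python detail: `bool` is a subclass of `int`, so a boolean feature passes the `isinstance(value, int)` check. `get_strict_int_feature` is a hypothetical stricter variant, not part of the diff.

```python
from typing import Any

# Trimmed, hypothetical subset of the FEATURES table, for illustration only.
FEATURES: dict[str, dict[str, Any]] = {
    "free": {"folder_max_files": 200, "proactive_mining": False},
    "power": {"folder_max_files": -1, "proactive_mining": True},
}


def get_feature_value(tier: str, feature: str) -> int:
    """Same logic as the method in the diff, written as a free function."""
    value = FEATURES.get(tier, FEATURES["free"]).get(feature)
    if not isinstance(value, int):
        return 0
    return value


def get_strict_int_feature(tier: str, feature: str) -> int:
    """Hypothetical variant that also rejects booleans."""
    value = FEATURES.get(tier, FEATURES["free"]).get(feature)
    if isinstance(value, bool) or not isinstance(value, int):
        return 0
    return value


known = get_feature_value("power", "folder_max_files")        # -1, i.e. unlimited
fallback = get_feature_value("enterprise", "folder_max_files")  # unknown tier uses "free"
missing = get_feature_value("free", "no_such_feature")          # absent key yields 0
boolean = get_feature_value("power", "proactive_mining")        # bool slips through as-is
strict = get_strict_int_feature("power", "proactive_mining")    # rejected by the variant
```

In the diff the method is only called with the two integer `folder_*` features, so the boolean case does not arise there; the stricter check would matter only if callers passed other feature names.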


@@ -16,6 +16,8 @@ class Settings(BaseSettings):
    ANTHROPIC_API_KEY: str = ""
    GOOGLE_API_KEY: str = ""
    CEREBRAS_API_KEY: str = ""
    GROQ_API_KEY: str = ""
    DEEPSEEK_API_KEY: str = ""

    LLM_MODEL: str = "gpt-4o"
    LLM_EMBED_MODEL: str = "text-embedding-3-small"
@@ -27,6 +29,7 @@ class Settings(BaseSettings):
    LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
    LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
    LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
    LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
    LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
    LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
    LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)


@@ -658,9 +658,14 @@ async def run_local_agent(
    # ── Phase B: single LLM call ─────────────────────────
    extraction_rules = _get_extraction_rules(agent_config, content_type)
    no_match_behavior = _get_no_match_behavior(agent_config)
-    global_rules_lines = "\n".join(
-        f"- {r}" for r in agent_config.get("global_rules", [])
-    )
    base_global_rules = list(agent_config.get("global_rules", []))
    if "notes" in config.data_types:
        base_global_rules.append(
            "For notes: when updating an existing note use `propose_note_edit` "
            "(type=append/insert/replace) so the user can review AI changes. "
            "Only call `update_note` for complete content replacement without review."
        )
    global_rules_lines = "\n".join(f"- {r}" for r in base_global_rules)
    metadata_section = _format_metadata(preprocessed.metadata)
    system_prompt = compile_prompt(


@@ -0,0 +1,59 @@
"""In-process TTL buffer for per-session LangChain message history.
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
conversation turns without it being lossy through the wire.
Single-process only. For multi-worker deployments, replace the _SessionBuffer
implementation with one backed by Redis (serialize LangChain messages to dicts via
message_to_dict / messages_from_dict from langchain_core.messages).
"""
from __future__ import annotations

import time
from threading import Lock

from langchain_core.messages import BaseMessage

SESSION_TTL_SECONDS = 1800  # 30-minute idle expiry
MAX_MESSAGES_PER_SESSION = 80  # cap to avoid unbounded memory growth


class _SessionBuffer:
    def __init__(self) -> None:
        self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
        self._lock = Lock()

    def _evict_stale(self) -> None:
        now = time.monotonic()
        stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
        for k in stale:
            del self._store[k]

    def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
        key = (user_id, session_id)
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return None
            ts, msgs = entry
            if time.monotonic() - ts > SESSION_TTL_SECONDS:
                del self._store[key]
                return None
            self._store[key] = (time.monotonic(), msgs)
            return list(msgs)

    def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
        key = (user_id, session_id)
        capped = messages[-MAX_MESSAGES_PER_SESSION:]
        with self._lock:
            self._evict_stale()
            self._store[key] = (time.monotonic(), capped)

    def clear(self, user_id: str, session_id: str) -> None:
        with self._lock:
            self._store.pop((user_id, session_id), None)


# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
session_buffer = _SessionBuffer()
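
One behaviour of `_SessionBuffer` that is easy to miss is that `get` refreshes the timestamp, so the 30-minute limit is an idle timeout rather than a fixed lifetime. The sketch below illustrates this with a hypothetical `TtlBuffer` that takes an injectable clock so time can be advanced by hand. It stores plain strings instead of LangChain messages and omits the lock, so it is a single-threaded illustration only.

```python
from __future__ import annotations

import time
from typing import Callable


class TtlBuffer:
    """Hypothetical simplified analogue of _SessionBuffer (no locking)."""

    def __init__(self, ttl_seconds: float, max_items: int,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self._ttl = ttl_seconds
        self._max = max_items
        self._clock = clock
        self._store: dict[str, tuple[float, list[str]]] = {}

    def set(self, key: str, items: list[str]) -> None:
        # Keep only the newest max_items entries, like MAX_MESSAGES_PER_SESSION.
        self._store[key] = (self._clock(), list(items[-self._max:]))

    def get(self, key: str) -> list[str] | None:
        entry = self._store.get(key)
        if entry is None:
            return None
        ts, items = entry
        if self._clock() - ts > self._ttl:
            del self._store[key]
            return None
        self._store[key] = (self._clock(), items)  # a read resets the idle timer
        return list(items)


now = [0.0]  # mutable fake clock
buf = TtlBuffer(ttl_seconds=1800, max_items=3, clock=lambda: now[0])
buf.set("s1", ["a", "b", "c", "d"])

now[0] = 1000.0
after_read = buf.get("s1")   # within TTL; timer resets to 1000
now[0] = 2500.0
still_alive = buf.get("s1")  # 1500 s idle since the last read; timer resets to 2500
now[0] = 4301.0
expired = buf.get("s1")      # 1801 s idle, past the TTL
```

Returning a copy from `get` (as the original does with `list(msgs)`) keeps callers from mutating the stored list in place.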


@@ -21,6 +21,7 @@ from app.core.deep_agent import (
    _relational_memory_injection,
    _run_single_agent_stream,
    _trace_id_from_context,
    build_brief_multi_project_manifest,
)
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
@@ -159,6 +160,8 @@ async def run_home_brief(
    Yields (event_type, data) tuples identical to _run_single_agent_stream.
    Do NOT post-process output through _normalize_tagged_list_lines.
    """
    from app.agents.folder_agent import FOLDER_TOOLS

    trace_id = _trace_id_from_context(context)
    today = date.today().isoformat()
    language = _resolve_language(context)
@@ -171,7 +174,10 @@ async def run_home_brief(
    if today not in system_prompt:
        system_prompt += f"\nToday is {today}."

-    tools = _build_read_tools(user_id, trace_id)
    brief_manifest = await build_brief_multi_project_manifest()
    system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
    tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]

    async for event in _run_single_agent_stream(
        user_id=user_id,
        system_prompt=system_prompt,


@@ -12,11 +12,14 @@ from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool

from app.agents.client_agent import CLIENT_TOOLS
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.relations_agent import make_query_relations_tool
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
-from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.agent_session_buffer import session_buffer
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
@@ -24,6 +27,8 @@ from app.db import async_session
logger = logging.getLogger(__name__)

MAX_HISTORY_TURNS = 20

FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
@@ -55,6 +60,182 @@ def _language_instruction(context: dict[str, Any]) -> str:
        f"All your output text must be written in {lang}."
    )


MANIFEST_TOKEN_BUDGET = 3000  # rough budget for <linked_folder> block


def format_folder_manifest(manifest: dict | None) -> str:
    """Format a folder manifest into the <linked_folder> block.

    Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
    Returns empty string if manifest is None or has no files.
    """
    if not manifest or not manifest.get("files"):
        return ""
    files = list(manifest["files"])
    files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
    header = (
        f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
        f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
    )
    footer_template = "{} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
    char_budget = MANIFEST_TOKEN_BUDGET * 4  # ~4 chars/token
    body = ""
    included = 0
    for f in files:
        line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
        if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
            break
        body += line
        included += 1
    omitted = len(files) - included
    if omitted > 0:
        return header + body + footer_template.format(omitted)
    return header + body + "</linked_folder>"


async def _fetch_project_manifest(project_id: str) -> dict | None:
    """Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
    from app.core.ws_context import execute_on_client

    try:
        result = await execute_on_client(
            action="read_project_folder_manifest",
            data={"projectId": project_id},
        )
        if not result or not result.get("folderPath"):
            return None
        return result
    except Exception:
        return None


async def build_brief_multi_project_manifest() -> str:
    """Build a compact multi-project manifest for the daily brief agent.

    Calls execute_on_client('list_projects_with_folder_manifests') and keeps
    the top 5 most-recently-modified files per project.
    """
    try:
        result = await execute_on_client(
            action="list_projects_with_folder_manifests",
            data={},
        )
    except Exception:
        return ""
    projects = (result or {}).get("projects") or []
    if not projects:
        return ""
    blocks: list[str] = ["<linked_folders>"]
    any_entry = False
    for p in projects:
        all_files = p.get("files", []) or []
        files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
        blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
        blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
        if not all_files:
            blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)")
        else:
            for f in files:
                blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
            if len(all_files) > 5:
                blocks.append(f"{len(all_files) - 5} more files (use read_project_folder_file by relPath)")
        any_entry = True
    if not any_entry:
        return ""
    blocks.append("</linked_folders>")
    return "\n".join(blocks)


def _datetime_context_injection(context: dict[str, Any]) -> str:
    """Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
    fp = context.get("format_prefs")
    if not fp or not isinstance(fp, dict):
        return ""
    try:
        from zoneinfo import ZoneInfo
        from datetime import datetime as _dt, timezone as _utc, timedelta as _td

        tz_name: str = str(fp.get("timezone") or "UTC")
        now_iso: str = str(fp.get("now_iso") or "")
        date_fmt: str = str(fp.get("date_format") or "dd/MM/yyyy")
        time_fmt: str = str(fp.get("time_format") or "24h")
        tz = ZoneInfo(tz_name)
        if now_iso:
            now_utc = _dt.fromisoformat(now_iso.replace("Z", "+00:00"))
        else:
            now_utc = _dt.now(_utc.utc)
        now_ms = int(now_utc.timestamp() * 1000)
        now_local = now_utc.astimezone(tz)
        now_local_str = now_local.strftime("%Y-%m-%d %H:%M")
        weekday_str = now_local.strftime("%A")
        y, m, d = now_local.year, now_local.month, now_local.day

        def _day(year: int, month: int, day: int) -> tuple[int, int]:
            s = _dt(year, month, day, tzinfo=tz)
            e = s + _td(days=1)
            return int(s.timestamp() * 1000), int(e.timestamp() * 1000) - 1

        def _between(start: "_dt", end_excl: "_dt") -> tuple[int, int]:
            return int(start.timestamp() * 1000), int(end_excl.timestamp() * 1000) - 1

        today_s, today_e = _day(y, m, d)
        yd = now_local - _td(days=1)
        yesterday_s, yesterday_e = _day(yd.year, yd.month, yd.day)
        tm = now_local + _td(days=1)
        tomorrow_s, tomorrow_e = _day(tm.year, tm.month, tm.day)

        # ISO week (Mon–Sun)
        monday = _dt(y, m, d, tzinfo=tz) - _td(days=now_local.weekday())
        last_monday = monday - _td(weeks=1)
        next_monday = monday + _td(weeks=1)
        this_week_s, this_week_e = _between(monday, next_monday)
        last_week_s, last_week_e = _between(last_monday, monday)
        next_week_s, next_week_e = _between(next_monday, next_monday + _td(weeks=1))

        # Calendar months
        this_m_start = _dt(y, m, 1, tzinfo=tz)
        next_m_start = _dt(y + (m // 12), m % 12 + 1, 1, tzinfo=tz)
        last_m_start = _dt(y - (1 if m == 1 else 0), 12 if m == 1 else m - 1, 1, tzinfo=tz)
        next2_m = next_m_start.month % 12 + 1
        next2_y = next_m_start.year + (1 if next_m_start.month == 12 else 0)
        next2_m_start = _dt(next2_y, next2_m, 1, tzinfo=tz)
        this_month_s, this_month_e = _between(this_m_start, next_m_start)
        last_month_s, last_month_e = _between(last_m_start, this_m_start)
        next_month_s, next_month_e = _between(next_m_start, next2_m_start)

        # Calendar years
        this_yr_s, this_yr_e = _between(_dt(y, 1, 1, tzinfo=tz), _dt(y + 1, 1, 1, tzinfo=tz))
        last_yr_s, last_yr_e = _between(_dt(y - 1, 1, 1, tzinfo=tz), _dt(y, 1, 1, tzinfo=tz))

        sunday = monday + _td(days=6)
        last_sunday = last_monday + _td(days=6)
        next_sunday = next_monday + _td(days=6)
        return (
            f"\n\nDATE CONTEXT (timezone: {tz_name}, dateFormat: {date_fmt}, timeFormat: {time_fmt})\n"
            f"now_local: {now_local_str} ({weekday_str})\n"
            f"now_ms: {now_ms}\n\n"
            f"today [{today_s}, {today_e}] {y:04d}-{m:02d}-{d:02d}\n"
            f"tomorrow [{tomorrow_s}, {tomorrow_e}] {tm.strftime('%Y-%m-%d')}\n"
            f"yesterday [{yesterday_s}, {yesterday_e}] {yd.strftime('%Y-%m-%d')}\n"
            f"this_week [{this_week_s}, {this_week_e}] {monday.strftime('%Y-%m-%d')}–{sunday.strftime('%Y-%m-%d')} (Mon–Sun)\n"
            f"last_week [{last_week_s}, {last_week_e}] {last_monday.strftime('%Y-%m-%d')}–{last_sunday.strftime('%Y-%m-%d')}\n"
            f"next_week [{next_week_s}, {next_week_e}] {next_monday.strftime('%Y-%m-%d')}–{next_sunday.strftime('%Y-%m-%d')}\n"
            f"this_month [{this_month_s}, {this_month_e}] {y:04d}-{m:02d}\n"
            f"last_month [{last_month_s}, {last_month_e}] {last_m_start.strftime('%Y-%m')}\n"
            f"next_month [{next_month_s}, {next_month_e}] {next_m_start.strftime('%Y-%m')}\n"
            f"this_year [{this_yr_s}, {this_yr_e}] {y:04d}\n"
            f"last_year [{last_yr_s}, {last_yr_e}] {y - 1:04d}\n\n"
            f"When calling list_tasks_due_today or list_timelines_today, always pass user_timezone=\"{tz_name}\".\n"
            f"When presenting dates, format using dateFormat={date_fmt} and timeFormat={time_fmt}."
        )
    except Exception:
        return ""
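
The month arithmetic in `_datetime_context_injection` is compact enough that the December and January edge cases are worth checking in isolation. The sketch below lifts the same expressions into two hypothetical helpers, `month_starts` and `to_ms_range`, and uses UTC instead of a user timezone so it runs without timezone data; it is an illustration, not code from the commit.

```python
from __future__ import annotations

from datetime import datetime, timezone, tzinfo


def month_starts(year: int, month: int, tz: tzinfo) -> tuple[datetime, datetime, datetime]:
    """Return the first instants of the previous, current, and next month."""
    this_start = datetime(year, month, 1, tzinfo=tz)
    # month // 12 is 1 only for December, so the year rolls over exactly then.
    next_start = datetime(year + (month // 12), month % 12 + 1, 1, tzinfo=tz)
    last_start = datetime(
        year - (1 if month == 1 else 0), 12 if month == 1 else month - 1, 1, tzinfo=tz
    )
    return last_start, this_start, next_start


def to_ms_range(start: datetime, end_exclusive: datetime) -> tuple[int, int]:
    """Inclusive [start_ms, end_ms]; the end is 1 ms before the next period starts."""
    return int(start.timestamp() * 1000), int(end_exclusive.timestamp() * 1000) - 1


dec_last, dec_this, dec_next = month_starts(2025, 12, timezone.utc)
jan_last, jan_this, jan_next = month_starts(2026, 1, timezone.utc)
jan_start_ms, jan_end_ms = to_ms_range(jan_this, jan_next)
```

Subtracting one millisecond from the exclusive end keeps adjacent ranges from overlapping when a consumer treats both bounds as inclusive, which is how the boundary table in the prompt is presented.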


def _proactive_hints_injection(context: dict[str, Any]) -> str:
    """Return a system-prompt paragraph listing proactive behavioral hints.
@@ -87,27 +268,203 @@ def _relational_memory_injection(context: dict[str, Any]) -> str:
    return section


-_HOME_SYSTEM_PROMPT = (
-    "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
-    "Always use tools for factual data retrieval before answering. "
-    "When the user asks to remember, forget, or update what you know about them, use memory tools. "
-    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
-    "Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
-    "<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
-    "When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
-    "Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
-    "For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
-    "For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
-)
-
-_FLOATING_SYSTEM_PROMPT = (
-    "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
-    "Stay focused on the floating scope in context.scope and answer concisely. "
-    "Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
-    "Always use tools for factual data retrieval before answering. "
-    "When the user asks to remember, forget, or update what you know about them, use memory tools. "
-    "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
-)
_IDENTITY_KEYS = ("user_name", "job_role", "industry", "primary_use_case", "tone_preference")


def _user_identity_injection(context: dict[str, Any]) -> str:
    """Return a compact user-profile block from core memory onboarding fields.

    Returns empty string when no onboarding keys are present.
    """
    core = context.get("core_memory") or {}
    parts: list[str] = []
    for key in _IDENTITY_KEYS:
        val = (core.get(key) or "").strip()
        if val:
            parts.append(f"- {key}: {val}")
    if not parts:
        return ""
    return "\n\nUser profile:\n" + "\n".join(parts)


def _request_context_block(context: dict[str, Any]) -> str:
    """Return a small block with per-request scope and resolved project context."""
    parts: list[str] = []
    scope = context.get("scope")
    if scope and isinstance(scope, dict):
        parts.append(f"scope: {json.dumps(scope, ensure_ascii=True)}")
    resolved = context.get("resolved_project_id")
    if resolved and isinstance(resolved, str):
        parts.append(f"resolved_project_id: {resolved}")
    return "\n".join(parts)


_HOME_SYSTEM_PROMPT = """\
You are adiuvAI's home executive assistant.{user_identity}
You are not a chatbot — you are a proactive partner who runs ahead of the user, anticipates what they need next, and closes every reply with a concrete next step or a clarifying question.
# How you work
- Use tools before answering anything factual. Never guess counts, dates, or status.
- Prefer parallel tool calls when the questions are independent (e.g. counts per status). Chain calls when one result feeds the next.
- After delivering the answer, propose the next useful action: a follow-up task to draft, a deadline at risk, a project to triage, a person to remind. Use what you know about the user (job role, industry, primary use case) to make the suggestion relevant.
- Match the user's tone preference. Default to warm-but-direct; stay concise.
- When the user asks to remember, forget, or update something, use memory tools.
# Filter discipline
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
- The user's own name in the User profile block is for context only — it is NOT a default filter.
- When in doubt, omit `assignee` and return the global result.
# Output format
Return markdown. Reference entities with these tags exactly — one id per tag, each tag on its own line, no prefix/suffix text on the same line:
<project>id</project> <task>id</task> <note>id</note> <timeline>id</timeline>
When the answer contains a list of entities (any of the tags above), structure the reply as three blocks separated by blank lines:
1. One short intro line stating what is coming (count + scope, e.g. "Ecco i tuoi 18 task ad alta priorità:"). Match the user's language.
2. All entity tags, one per line, consecutive, no prose interleaved. Do NOT put titles, dates, priorities, or any descriptive text on the same line as a tag or between tags.
3. One short closing recap (1–2 sentences) that points out a pattern, risk, or insight noticed in the list, and ends with a concrete next step or clarifying question.
For single-entity answers skip blocks 1 and 3 if they would be redundant; just emit the tag.
For analytical answers (status overviews, breakdowns by category/priority/project, comparisons, trends, "resoconto", "panoramica") consider returning a chart block when it communicates the answer faster than prose. The decision is yours — skip charts for trivial single-number answers. Schema:
<chart>{{"chartType":"pie|bar|line|area|radar|radial","title":"...","data":[{{"name":"...","value":N}},...], "config":{{"value":{{"label":"...","color":"var(--chart-1)"}} }} }}</chart>
- pie for share-of-total breakdowns; bar for category comparisons; line/area for time series; radar for multi-dimension.
- data rows must include a "name" field; numeric series keys must match config keys.
- Use var(--chart-1) through var(--chart-5) for colors, cycling 1-5 in series order. Do NOT wrap in hsl() or oklch() — these are complete CSS values already.
For upcoming-timeline questions ("prossimi eventi"), include only future items in the current month unless the user asks otherwise.
# Date filtering
{date_context}
When filtering tasks/timelines/notes by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_today with user_timezone from DATE CONTEXT.
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Behavioral hints
{proactive_hints}
# Request context
{request_context}\
"""
_FLOATING_SYSTEM_PROMPT = """\
You are adiuvAI's floating executive assistant.{user_identity}
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
# How you work
- Use tools before answering anything factual. Never guess.
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant.
- Match the user's tone preference. Default to warm-but-direct.
- When the user asks to remember, forget, or update something, use memory tools.
# Filter discipline
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
- The user's own name in the User profile block is for context only — it is NOT a default filter.
- When in doubt, omit `assignee` and return the global result.
# Output format
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
# Date filtering
{date_context}
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Behavioral hints
{proactive_hints}
# Request context
{request_context}\
"""
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task.
Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity}
# Research workflow
Follow these steps in order, using tools:
1. Read the task fully (title, description, due date, priority, status, project, comments).
2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client.
3. If the project has a clientId: call `get_client(id)` to retrieve full client details.
4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects.
5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions.
6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`.
7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece.
# Output structure
Write the briefing in the user's language. Use this exact structure:
**What needs to be done**
(1–2 sentences, concrete and specific — what action the user must take)
**Context you should know**
(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies)
**Suggested first step**
(one specific, immediately actionable instruction)
If this is a writing task, append a canvas block at the very end:
<canvas kind="email|document|message">
...ready-to-use draft here...
</canvas>
Do NOT include the canvas block for non-writing tasks.
Do NOT repeat verbatim task fields the user already sees in the UI.
Be concrete — no vague advice. Every bullet should be a fact that changes what the user does.
# Date context
{date_context}
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Request context
{request_context}\
"""
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\
You are an executive assistant continuing a conversation with your principal.
You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity}
Your briefing:
---
{briefing_context}
---
Continue from here. Do NOT repeat the briefing. Refer to it when relevant.
Help the user execute: edit drafts, refine wording, look up additional details, plan next steps.
Stay terse — your principal is a busy executive.
# Date context
{date_context}
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Request context
{request_context}\
"""
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = ( _FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. " "You are a strict domain classifier for websocket floating requests. "
@@ -217,10 +574,19 @@ def _session_id_from_context(context: dict[str, Any]) -> str | None:
    return None


-def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
-    sanitized = dict(context)
-    sanitized.pop("_debug", None)
-    return sanitized
def _build_system_prompt(name: str, fallback: str, context: dict[str, Any]) -> tuple[str, Any]:
    """Fetch Langfuse template and compile all per-request slots into one system prompt."""
    template, prompt_obj = get_prompt_or_fallback(name, fallback)
    text = compile_prompt(
        template, prompt_obj,
        date_context=_datetime_context_injection(context).strip(),
        language_instruction=_language_instruction(context).strip(),
        user_identity=_user_identity_injection(context).strip(),
        relational_memory=_relational_memory_injection(context).strip(),
        proactive_hints=_proactive_hints_injection(context).strip(),
        request_context=_request_context_block(context),
    )
    return text, prompt_obj


_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
@@ -476,6 +842,25 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
lines = [f"- {item}" for item in results] lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines) return "Recall memory results:\n" + "\n".join(lines)
@tool
async def search_associative(query: str, limit: int = 5) -> str:
"""Semantic search across associative (archival) memory for a given query.
Use this to surface long-term memories related to a topic, client, or task
that may not appear in recent episodes.
query: natural-language search phrase.
limit: max results (default 5).
"""
logger.info("deep_agent: search_associative trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=limit)
if not results:
return "No associative memory results found."
lines = [f"- {item}" for item in results]
return "Associative memory results:\n" + "\n".join(lines)
return [ return [
memory_list_blocks, memory_list_blocks,
memory_get, memory_get,
@@ -486,16 +871,33 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
archival_memory_insert, archival_memory_insert,
archival_memory_search, archival_memory_search,
conversation_search, conversation_search,
search_associative,
] ]
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]: def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return memory tools that only read — safe for the read-only brief-agent subset.""" """Return memory tools that only read — safe for the read-only brief-agent subset."""
all_mem = _memory_tools(user_id, trace_id) all_mem = _memory_tools(user_id, trace_id)
-    _read_names = {"memory_list_blocks", "memory_get", "archival_memory_search", "conversation_search"}
+    _read_names = {
+        "memory_list_blocks", "memory_get", "archival_memory_search",
+        "conversation_search", "search_associative",
+    }
return [t for t in all_mem if t.name in _read_names] return [t for t in all_mem if t.name in _read_names]
def _brief_research_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return the full tool palette for Stage-1 task brief research (read-only)."""
return [
*TASK_TOOLS,
*PROJECT_TOOLS,
*NOTE_TOOLS,
*TIMELINE_TOOLS,
*CLIENT_TOOLS,
*_read_only_memory_tools(user_id, trace_id),
make_query_relations_tool(user_id, trace_id),
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]: def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)] return [*_all_tools(), *_memory_tools(user_id, trace_id)]
@@ -662,6 +1064,23 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
return _infer_floating_domain_rule_based(message, context) return _infer_floating_domain_rule_based(message, context)
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
if not history:
return []
turns = history[-MAX_HISTORY_TURNS:]
result: list[Any] = []
for turn in turns:
role = turn.get("role", "")
content = turn.get("content", "")
if not content:
continue
if role == "user":
result.append(HumanMessage(content=content))
elif role == "assistant":
result.append(AIMessage(content=content))
return result
async def _run_single_agent( async def _run_single_agent(
*, *,
user_id: str, user_id: str,
@@ -671,23 +1090,21 @@ async def _run_single_agent(
     max_steps: int = 6,
     langfuse_prompt: Any = None,
     agent_name: str = "agent",
+    conversation_history: list[dict[str, str]] | None = None,
 ) -> str:
     trace_id = _trace_id_from_context(context)
     session_id = _session_id_from_context(context)
     lf = get_langfuse()
     llm = get_agent_llm(agent_name)
     tools = _all_tools_for_user(user_id, trace_id)
-    model_context = _context_for_model(context)
     logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
     llm_with_tools = llm.bind_tools(tools)
+    _buffered = session_buffer.get(user_id, session_id) if session_id else None
+    history_messages = _buffered if _buffered is not None else _history_to_messages(conversation_history)
     messages: list[Any] = [
         SystemMessage(content=system_prompt),
-        HumanMessage(
-            content=(
-                f"User message:\n{message}\n\n"
-                f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
-            )
-        ),
+        *history_messages,
+        HumanMessage(content=message),
     ]
     tool_calls_count = 0
@@ -699,7 +1116,7 @@ async def _run_single_agent(
_span_ctx = ( _span_ctx = (
lf.start_as_current_observation( lf.start_as_current_observation(
-            as_type="span",
+            as_type="agent",
name=agent_name, name=agent_name,
metadata={"user_id": user_id, "session_id": trace_id}, metadata={"user_id": user_id, "session_id": trace_id},
input=message, input=message,
@@ -707,6 +1124,7 @@ async def _run_single_agent(
if lf else None if lf else None
) )
_span = _span_ctx.__enter__() if _span_ctx else None _span = _span_ctx.__enter__() if _span_ctx else None
_messages_to_save: list[Any] | None = None
try: try:
for _ in range(max_steps): for _ in range(max_steps):
@@ -739,6 +1157,7 @@ async def _run_single_agent(
) )
if _span: if _span:
_span.update(output=final_text) _span.update(output=final_text)
_messages_to_save = messages[1:] # strip SystemMessage; save full tool history
return final_text return final_text
tool_map = {tool_def.name: tool_def for tool_def in tools} tool_map = {tool_def.name: tool_def for tool_def in tools}
@@ -757,6 +1176,14 @@ async def _run_single_agent(
tool_fn = tool_map.get(call_name) tool_fn = tool_map.get(call_name)
if tool_fn is None: if tool_fn is None:
tool_output = f"Unknown tool: {call_name}" tool_output = f"Unknown tool: {call_name}"
elif lf:
with lf.start_as_current_observation(
as_type="tool",
name=call_name,
input=call_args,
) as tool_obs:
tool_output = await tool_fn.ainvoke(call_args)
tool_obs.update(output=str(tool_output)[:8000])
else: else:
tool_output = await tool_fn.ainvoke(call_args) tool_output = await tool_fn.ainvoke(call_args)
@@ -771,6 +1198,7 @@ async def _run_single_agent(
final = await llm.ainvoke(messages) final = await llm.ainvoke(messages)
final_text = _as_text(final.content) final_text = _as_text(final.content)
messages.append(AIMessage(content=final_text))
logger.info( logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1", "deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-", trace_id or "-",
@@ -780,8 +1208,11 @@ async def _run_single_agent(
) )
if _span: if _span:
_span.update(output=final_text) _span.update(output=final_text)
_messages_to_save = messages[1:]
return final_text return final_text
finally: finally:
if session_id and _messages_to_save is not None:
session_buffer.set(user_id, session_id, _messages_to_save)
clear_tool_result_collector() clear_tool_result_collector()
if _span_ctx: if _span_ctx:
_span_ctx.__exit__(None, None, None) _span_ctx.__exit__(None, None, None)
@@ -800,6 +1231,7 @@ async def _run_single_agent_stream(
langfuse_prompt: Any = None, langfuse_prompt: Any = None,
agent_name: str = "agent", agent_name: str = "agent",
tools: list[Any] | None = None, tools: list[Any] | None = None,
conversation_history: list[dict[str, str]] | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context) trace_id = _trace_id_from_context(context)
session_id = _session_id_from_context(context) session_id = _session_id_from_context(context)
@@ -807,17 +1239,14 @@ async def _run_single_agent_stream(
     llm = get_agent_llm(agent_name)
     if tools is None:
         tools = _all_tools_for_user(user_id, trace_id)
-    model_context = _context_for_model(context)
     logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
     llm_with_tools = llm.bind_tools(tools)
+    _buffered = session_buffer.get(user_id, session_id) if session_id else None
+    history_messages = _buffered if _buffered is not None else _history_to_messages(conversation_history)
     messages: list[Any] = [
         SystemMessage(content=system_prompt),
-        HumanMessage(
-            content=(
-                f"User message:\n{message}\n\n"
-                f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
-            )
-        ),
+        *history_messages,
+        HumanMessage(content=message),
     ]
     tool_calls_count = 0
@@ -830,7 +1259,7 @@ async def _run_single_agent_stream(
_span_ctx = ( _span_ctx = (
lf.start_as_current_observation( lf.start_as_current_observation(
-            as_type="span",
+            as_type="agent",
name=f"{agent_name}-stream", name=f"{agent_name}-stream",
metadata={"user_id": user_id, "session_id": trace_id}, metadata={"user_id": user_id, "session_id": trace_id},
input=message, input=message,
@@ -839,6 +1268,7 @@ async def _run_single_agent_stream(
) )
_span = _span_ctx.__enter__() if _span_ctx else None _span = _span_ctx.__enter__() if _span_ctx else None
streamed_text: list[str] = [] streamed_text: list[str] = []
_messages_to_save: list[Any] | None = None
try: try:
for _ in range(max_steps): for _ in range(max_steps):
@@ -858,25 +1288,15 @@ async def _run_single_agent_stream(
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response)) _gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None) _gen_ctx.__exit__(None, None, None)
-            messages.append(response)
             if not response.tool_calls:
-                emitted_any = False
-                async for chunk in llm.astream(messages):
-                    token = _as_text(getattr(chunk, "content", ""))
-                    if token:
-                        streamed_chars += len(token)
-                        streamed_text.append(token)
-                        emitted_any = True
-                        yield "token", token
-                # Some providers return final text in `response.content` but stream no chunks.
-                if not emitted_any:
-                    fallback_text = _as_text(response.content)
-                    if fallback_text:
-                        streamed_chars += len(fallback_text)
-                        streamed_text.append(fallback_text)
-                        yield "token", fallback_text
+                # Yield the content from the ainvoke response directly — no second LLM call.
+                # Previously, messages.append(response) was called first, so the re-stream
+                # received [System, Human, AI] and regenerated a response without tools bound.
+                final_text = _as_text(response.content)
+                if final_text:
+                    streamed_chars += len(final_text)
+                    streamed_text.append(final_text)
+                    yield "token", final_text
logger.info( logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d", "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-", trace_id or "-",
@@ -886,8 +1306,11 @@ async def _run_single_agent_stream(
) )
if _span: if _span:
_span.update(output="".join(streamed_text)) _span.update(output="".join(streamed_text))
messages.append(response)
_messages_to_save = messages[1:] # strip SystemMessage
return return
messages.append(response)
tool_map = {tool_def.name: tool_def for tool_def in tools} tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls: for call in response.tool_calls:
tool_calls_count += 1 tool_calls_count += 1
@@ -904,6 +1327,14 @@ async def _run_single_agent_stream(
tool_fn = tool_map.get(call_name) tool_fn = tool_map.get(call_name)
if tool_fn is None: if tool_fn is None:
tool_output = f"Unknown tool: {call_name}" tool_output = f"Unknown tool: {call_name}"
elif lf:
with lf.start_as_current_observation(
as_type="tool",
name=call_name,
input=call_args,
) as tool_obs:
tool_output = await tool_fn.ainvoke(call_args)
tool_obs.update(output=str(tool_output)[:8000])
else: else:
tool_output = await tool_fn.ainvoke(call_args) tool_output = await tool_fn.ainvoke(call_args)
@@ -916,12 +1347,16 @@ async def _run_single_agent_stream(
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
fallback_chunks: list[str] = []
async for chunk in llm.astream(messages): async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", "")) token = _as_text(getattr(chunk, "content", ""))
if token: if token:
streamed_chars += len(token) streamed_chars += len(token)
streamed_text.append(token) streamed_text.append(token)
fallback_chunks.append(token)
yield "token", token yield "token", token
messages.append(AIMessage(content="".join(fallback_chunks)))
_messages_to_save = messages[1:]
logger.info( logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1", "deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-", trace_id or "-",
@@ -932,6 +1367,8 @@ async def _run_single_agent_stream(
if _span: if _span:
_span.update(output="".join(streamed_text)) _span.update(output="".join(streamed_text))
finally: finally:
if session_id and _messages_to_save is not None:
session_buffer.set(user_id, session_id, _messages_to_save)
clear_tool_result_collector() clear_tool_result_collector()
if _span_ctx: if _span_ctx:
_span_ctx.__exit__(None, None, None) _span_ctx.__exit__(None, None, None)
@@ -942,12 +1379,7 @@ async def _run_single_agent_stream(
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
-    system_prompt, langfuse_prompt = get_prompt_or_fallback(
-        "home_system", _HOME_SYSTEM_PROMPT
-    )
-    system_prompt += _relational_memory_injection(context)
-    system_prompt += _proactive_hints_injection(context)
-    system_prompt += _language_instruction(context)
+    system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
response = await _run_single_agent( response = await _run_single_agent(
user_id=user_id, user_id=user_id,
system_prompt=system_prompt, system_prompt=system_prompt,
@@ -955,6 +1387,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="home-agent", agent_name="home-agent",
conversation_history=context.get("conversation_history"),
) )
return _normalize_tagged_list_lines(response, message) return _normalize_tagged_list_lines(response, message)
@@ -962,12 +1395,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]: async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context) domain = await _infer_floating_domain(message, prepared_context)
-    system_prompt, langfuse_prompt = get_prompt_or_fallback(
-        "floating_system", _FLOATING_SYSTEM_PROMPT
-    )
-    system_prompt += _relational_memory_injection(context)
-    system_prompt += _proactive_hints_injection(context)
-    system_prompt += _language_instruction(context)
+    system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
response = await _run_single_agent( response = await _run_single_agent(
user_id=user_id, user_id=user_id,
system_prompt=system_prompt, system_prompt=system_prompt,
@@ -975,6 +1403,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="floating-agent", agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
) )
sanitized = _strip_floating_markup(response) sanitized = _strip_floating_markup(response)
if not sanitized and response: if not sanitized and response:
@@ -986,14 +1415,26 @@ async def run_home_stream(
user_id: str, user_id: str,
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
-    system_prompt, langfuse_prompt = get_prompt_or_fallback(
-        "home_system", _HOME_SYSTEM_PROMPT
-    )
-    system_prompt += _relational_memory_injection(context)
-    system_prompt += _proactive_hints_injection(context)
-    system_prompt += _language_instruction(context)
+    system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
+    manifest_block = ""
+    if project_id:
+        manifest = await _fetch_project_manifest(project_id)
+        manifest_block = format_folder_manifest(manifest)
+    if not manifest_block:
+        # No specific project context — surface all linked folders so the agent
+        # can answer questions like "tell me about project X" using its files.
+        manifest_block = await build_brief_multi_project_manifest()
+    system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
+    trace_id = _trace_id_from_context(prepared_context)
+    tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
text_chunks: list[str] = [] text_chunks: list[str] = []
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
@@ -1002,6 +1443,8 @@ async def run_home_stream(
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="home-agent", agent_name="home-agent",
tools=tools,
conversation_history=context.get("conversation_history"),
): ):
event_type, data = event event_type, data = event
if event_type != "token": if event_type != "token":
@@ -1023,12 +1466,29 @@ async def run_floating_stream(
domain = await _infer_floating_domain(message, prepared_context) domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain yield "floating_domain", domain
-    system_prompt, langfuse_prompt = get_prompt_or_fallback(
-        "floating_system", _FLOATING_SYSTEM_PROMPT
-    )
-    system_prompt += _relational_memory_injection(context)
-    system_prompt += _proactive_hints_injection(context)
-    system_prompt += _language_instruction(context)
+    brief_mode: bool = bool(context.get("brief_mode"))
+    briefing_context_text: str = str(context.get("briefing_context") or "").strip()
+    if brief_mode and briefing_context_text:
+        # Stage 2: inject briefing as ground truth context.
+        # Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
+        # before compile_prompt sees the remaining standard variables.
+        template, langfuse_prompt = get_prompt_or_fallback(
+            "task_brief_followup_system",
+            _TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
+        )
+        system_prompt = compile_prompt(
+            template, langfuse_prompt,
+            date_context=_datetime_context_injection(prepared_context).strip(),
+            language_instruction=_language_instruction(prepared_context).strip(),
+            user_identity=_user_identity_injection(prepared_context).strip(),
+            relational_memory=_relational_memory_injection(prepared_context).strip(),
+            proactive_hints=_proactive_hints_injection(prepared_context).strip(),
+            request_context=_request_context_block(prepared_context),
+            briefing_context=briefing_context_text,
+        )
+    else:
+        system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
sanitizer = _FloatingStreamSanitizer() sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False emitted_sanitized = False
raw_chunks: list[str] = [] raw_chunks: list[str] = []
@@ -1039,6 +1499,7 @@ async def run_floating_stream(
context=prepared_context, context=prepared_context,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
agent_name="floating-agent", agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
): ):
event_type, data = event event_type, data = event
if event_type != "token": if event_type != "token":
@@ -1061,6 +1522,58 @@ async def run_floating_stream(
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks)) yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def run_task_brief_research_stream(
user_id: str,
task_id: str,
context: dict[str, Any],
project_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stage-1 executive assistant: deep research for one task.
Yields ``("token", chunk)`` events like other stream runners.
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
which the WS handler strips and emits as a ``canvas_draft`` mutation.
"""
from app.agents.folder_agent import FOLDER_TOOLS
prepared_context = await _prepare_context(f"task:{task_id}", context)
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
# Inject task_id so the agent knows what to look up first.
research_message = (
f"Prepare a briefing dossier for task ID: {task_id}\n"
"Follow the research workflow: read the task, then project, then client, "
"then cross-project relations, then relevant memory. "
"End with a concrete suggested first step. "
"If this is a writing task, include a <canvas kind=\"...\"> draft."
)
system_prompt, langfuse_prompt = _build_system_prompt(
"task_brief_research_system",
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT,
prepared_context,
)
manifest_block = ""
if project_id:
manifest = await _fetch_project_manifest(project_id)
manifest_block = format_folder_manifest(manifest)
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=research_message,
context=prepared_context,
max_steps=12,
langfuse_prompt=langfuse_prompt,
agent_name="task-brief-agent",
tools=tools,
conversation_history=None,
):
yield event
async def update_core_memory(user_id: str, key: str, value: str) -> None: async def update_core_memory(user_id: str, key: str, value: str) -> None:
"""Compatibility helper kept for callers that expect explicit memory update API.""" """Compatibility helper kept for callers that expect explicit memory update API."""
async with async_session() as db: async with async_session() as db:

app/core/folder_indexer.py

@@ -0,0 +1,183 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_langfuse,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
summary: str
tokens_used: int
async def _llm_text(messages: list) -> object:
"""Make the LLM call for text summarisation.
Defined as a standalone async function so tests can patch it cleanly
without needing to mock the LLM object itself.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def _llm_vision(messages: list) -> object:
"""Make the LLM call for vision (image) summarisation.
Accepts the message list and returns the response directly, mirroring
the ``_llm_text`` caller pattern so tests can patch it at the module level.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
"""Return a compact summary of an image file using vision.
Parameters
----------
image_b64:
Base64-encoded image bytes.
mime:
MIME type of the image, e.g. ``"image/png"``.
file_name:
Optional file name, attached to the Langfuse trace as input metadata.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
messages = [
SystemMessage(content=template),
HumanMessage(content=[
{"type": "text", "text": "Summarise this image."},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
]),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-image",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": file_name, "mime": mime},
) as gen:
response = await _llm_vision(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_vision(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
"""Return a compact summary of a text file.
Parameters
----------
content:
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
ext:
File extension including the leading dot, e.g. ``".md"``.
name:
File name, e.g. ``"kickoff.md"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
truncated = content[:_MAX_INPUT_CHARS]
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
messages = [
SystemMessage(content=compiled),
HumanMessage(content="Summarise this file."),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-text",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
) as gen:
response = await _llm_text(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_text(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
    buf = io.BytesIO(base64.b64decode(pdf_b64))
    reader = PdfReader(buf)
    parts: list[str] = []
    for page in reader.pages:
        try:
            parts.append(page.extract_text() or "")
        except Exception:
            continue
    return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
    buf = io.BytesIO(base64.b64decode(docx_b64))
    doc = DocxDocument(buf)
    return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a PDF file.
Parameters
----------
pdf_b64:
Base64-encoded PDF bytes.
name:
File name, e.g. ``"report.pdf"``.
"""
text = _extract_pdf_text(pdf_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".pdf", name=name)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a DOCX file.
Parameters
----------
docx_b64:
Base64-encoded DOCX bytes.
name:
File name, e.g. ``"spec.docx"``.
"""
text = _extract_docx_text(docx_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".docx", name=name)
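Reviewer note on the shape shared by `summarize_pdf` and `summarize_docx`: extract text, short-circuit with a zero-token result when nothing was extracted, otherwise truncate the input and cap the summary. A minimal sketch of that flow follows, assuming an injected `llm_call` coroutine in place of `_llm_text`. `summarize_extracted` and `fake_llm` are hypothetical names, and the 6000/500 limits are copied from the file above.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable

MAX_INPUT_CHARS = 6000    # mirrors _MAX_INPUT_CHARS in the diff
MAX_SUMMARY_CHARS = 500   # mirrors the [:500] cap in the diff


@dataclass
class IndexResult:
    summary: str
    tokens_used: int


async def summarize_extracted(
    text: str,
    llm_call: Callable[[str], Awaitable[tuple[str, int]]],
) -> IndexResult:
    """Hypothetical helper: skip the LLM when extraction produced no text."""
    if not text.strip():
        return IndexResult(summary="Could not extract text", tokens_used=0)
    reply, tokens = await llm_call(text[:MAX_INPUT_CHARS])
    return IndexResult(summary=reply.strip()[:MAX_SUMMARY_CHARS], tokens_used=tokens)


async def fake_llm(prompt: str) -> tuple[str, int]:
    """Test double standing in for the real model call."""
    return f"  summary of {len(prompt)} chars  ", 42


if __name__ == "__main__":
    print(asyncio.run(summarize_extracted("x" * 7000, fake_llm)))
```

Passing the model call in as a parameter serves the same purpose the docstrings give for `_llm_text` and `_llm_vision`: tests can swap the call without mocking the LLM object.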


@@ -51,6 +51,10 @@ def _api_key_for_model(model: str) -> str | None:
return settings.GOOGLE_API_KEY or None return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"): if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None return settings.CEREBRAS_API_KEY or None
if model.startswith("groq/"):
return settings.GROQ_API_KEY or None
if model.startswith("deepseek/"):
return settings.DEEPSEEK_API_KEY or None
if model.startswith("github_copilot/"): if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM. # GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth. # No API key is required; returning None lets LiteLLM handle auth.
@@ -103,10 +107,12 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL, "unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL, "cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL, "brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL, "setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini", "memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini", "memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL, "memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
"note-summarizer": lambda: "gpt-4o-mini",
} }
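Reviewer note on `_AGENT_MODEL_SETTINGS`: each entry is a zero-argument callable, so the setting is read at lookup time rather than at import time. The sketch below shows that behaviour with a stand-in `settings` object. `resolve_model` and its fallback to `settings.LLM_MODEL` for unknown agent names are assumptions, since the real lookup function is outside this hunk.

```python
from types import SimpleNamespace
from typing import Callable

# Stand-in for the real settings object; field names follow the diff.
settings = SimpleNamespace(LLM_MODEL="default-model", LLM_MODEL_TASK_BRIEF_AGENT="")

AGENT_MODELS: dict[str, Callable[[], str]] = {
    "task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
    "note-summarizer": lambda: "gpt-4o-mini",
}


def resolve_model(agent_name: str) -> str:
    """Hypothetical lookup: call the getter lazily, fall back to the global default."""
    getter = AGENT_MODELS.get(agent_name)
    return getter() if getter else settings.LLM_MODEL
```

Because the lambdas close over `settings` instead of capturing a value, changing `settings.LLM_MODEL_TASK_BRIEF_AGENT` after import takes effect on the next call. The `note-summarizer` entry, by contrast, is pinned and cannot be overridden through settings.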


@@ -0,0 +1,51 @@
"""Note summarizer — generates a compact AI summary for a note.

Called fire-and-forget from create_note / update_note tools so the
``notes.ai_summary`` column stays current without blocking the agent loop.
"""
from __future__ import annotations

import logging

from langchain_core.messages import HumanMessage, SystemMessage

from app.core.langfuse_client import get_prompt_or_fallback
from app.core.llm import get_agent_llm

logger = logging.getLogger(__name__)

_FALLBACK_PROMPT = """\
Summarize this note in <=250 characters. Be terse and dense.
Keep proper nouns, dates, decisions, and action items.
Do not start with "This note".
Respond with the summary text only — no intro, no labels.
Title: {title}
Content: {content}"""

_MAX_CONTENT_CHARS = 4000


async def generate_note_summary(title: str, content: str) -> str:
    """Return a <=250-char summary of *title* + *content*.

    Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
    fallback. Truncates *content* to 4000 chars before sending to avoid
    token waste on large notes.
    """
    template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
    trimmed = content[:_MAX_CONTENT_CHARS]
    system_prompt = template.format(title=title, content=trimmed)
    try:
        llm = get_agent_llm("note-summarizer")
        response = await llm.ainvoke([
            SystemMessage(content=system_prompt),
            HumanMessage(content="Generate the summary."),
        ])
        text = response.content if isinstance(response.content, str) else ""
        return text.strip()[:250]
    except Exception as exc:
        logger.warning("note_summarizer: failed to generate summary: %s", exc)
        return ""
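The module docstring says the summarizer is called fire-and-forget from the note tools, but the calling side is not part of this diff. A minimal sketch of that pattern with plain asyncio tasks follows. `fire_and_forget`, `refresh_summary`, and `_stub_summary` are hypothetical names used for illustration only; the stub stands in for `generate_note_summary` so the example runs without an LLM.

```python
import asyncio
from typing import Any, Coroutine

# Strong references keep pending tasks from being garbage-collected early.
_pending: set = set()


def fire_and_forget(coro: Coroutine[Any, Any, None]) -> asyncio.Task:
    """Schedule *coro* on the running loop without awaiting it."""
    task = asyncio.create_task(coro)
    _pending.add(task)
    task.add_done_callback(_pending.discard)
    return task


async def _stub_summary(title: str, content: str) -> str:
    """Hypothetical stand-in for generate_note_summary (no LLM call)."""
    await asyncio.sleep(0)
    return f"{title}: {content}"[:250]


async def refresh_summary(store: dict, note_id: str, title: str, content: str) -> None:
    """Compute the summary and write it to *store* (a stand-in for notes.ai_summary)."""
    store[note_id] = await _stub_summary(title, content[:4000])


async def demo() -> dict:
    store: dict = {}
    task = fire_and_forget(refresh_summary(store, "n1", "Kickoff", "scope and deadlines"))
    # A real tool would return here; the demo awaits only to observe the result.
    await task
    return store
```

Because `generate_note_summary` already swallows exceptions and returns an empty string, a background task built this way would presumably never surface an LLM failure to the agent loop.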


@@ -2,11 +2,35 @@
from __future__ import annotations

import re
from collections.abc import AsyncGenerator
from typing import Any

from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText

# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
    r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
    re.DOTALL | re.IGNORECASE,
)


def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
    """Strip the first <canvas kind="...">...</canvas> block from *text*.

    Returns ``(visible_text, canvas_content, canvas_kind)``.
    ``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
    """
    match = _CANVAS_BLOCK_RE.search(text)
    if not match:
        return text, None, None
    canvas_kind = match.group(1).strip()
    canvas_content = match.group(2).strip()
    visible = text[: match.start()] + text[match.end() :]
    visible = visible.strip()
    return visible, canvas_content, canvas_kind


WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
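To see what `extract_canvas_block` returns in practice, the sketch below repeats the regex and function from the hunk above in condensed form and runs them on two sample replies. The sample strings are invented for illustration.

```python
import re
from typing import Optional, Tuple

_CANVAS_BLOCK_RE = re.compile(
    r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
    re.DOTALL | re.IGNORECASE,
)


def extract_canvas_block(text: str) -> Tuple[str, Optional[str], Optional[str]]:
    """Condensed copy of the function added in this diff."""
    match = _CANVAS_BLOCK_RE.search(text)
    if not match:
        return text, None, None
    visible = (text[: match.start()] + text[match.end():]).strip()
    return visible, match.group(2).strip(), match.group(1).strip()


# One block: it is removed from the visible text and returned separately.
reply = 'Here is the plan.\n<canvas kind="markdown">\n# Q3 roadmap\n</canvas>\nAnything else?'
visible, content, kind = extract_canvas_block(reply)

# Two blocks: only the first is extracted; the second stays in the visible text.
two = '<canvas kind="a">x</canvas><canvas kind="b">y</canvas>'
visible_two, content_two, kind_two = extract_canvas_block(two)
```

Note that the function strips only the first block, so a reply containing several canvas blocks would leave the later ones in the text shown to the user.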


@@ -7,10 +7,32 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
from __future__ import annotations

import re
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4

_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")


def _key_to_camel(key: str) -> str:
    return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)


def _keys_to_camel(obj: Any) -> Any:
    """Recursively convert dict keys from snake_case to camelCase.

    Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
    ``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
    tool_result payloads in ``toSnakeCase`` before sending; this restores the
    camelCase schema property names that the tool code expects to read.
    """
    if isinstance(obj, dict):
        return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_keys_to_camel(v) for v in obj]
    return obj


# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
@@ -82,6 +104,7 @@ async def execute_on_client(
payload["limit"] = limit
result = await callback(payload)
result = _keys_to_camel(result)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({


@@ -243,6 +243,7 @@ class AgentRunLog(Base):
    status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
    items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
    errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
    started_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now()
@@ -263,6 +264,17 @@ class AgentRunLog(Base):
    )


class MonthlyTokenUsage(Base):
    __tablename__ = "monthly_token_usage"

    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
    )
    year_month: Mapped[str] = mapped_column(String(7), primary_key=True)  # 'YYYY-MM'
    feature: Mapped[str] = mapped_column(String(64), primary_key=True)
    tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")


# ── Memory models ─────────────────────────────────────────────────────────────


@@ -87,6 +87,15 @@ class WsFrameType(str, Enum):
    journey_reply = "journey_reply"
    # ── v5 brief frame types ──────────────────────────────────────────
    brief_request = "brief_request"
    # ── v6 task brief frame types ─────────────────────────────────────
    task_brief_request = "task_brief_request"
    # ── v7 folder index frame types ───────────────────────────────────
    index_session_start = "index_session_start"
    index_file_batch = "index_file_batch"
    index_session_cancel = "index_session_cancel"
    index_file_result = "index_file_result"
    index_session_progress = "index_session_progress"
    index_session_done = "index_session_done"


class WsToolCall(BaseModel):
@@ -142,6 +151,16 @@ class WsDeviceHello(BaseModel):
# ── WebSocket v3 Frame Models ─────────────────────────────────────────


class FormatPrefsModel(BaseModel):
    """User display preferences sent by Electron on each request."""

    timezone: str = "UTC"
    date_format: str = "dd/MM/yyyy"
    time_format: str = "24h"
    locale: str = "en-US"
    now_iso: str = ""


class WsFloatingScope(BaseModel):
    """Scope for a floating request — narrows the agent to a specific entity."""
@@ -155,6 +174,7 @@ class WsHomeRequest(BaseModel):
    type: Literal[WsFrameType.home_request] = WsFrameType.home_request
    message: str
    conversation_history: list[dict[str, Any]] = Field(default_factory=list)
    format_prefs: FormatPrefsModel | None = None


class WsFloatingRequest(BaseModel):
@@ -163,6 +183,7 @@ class WsFloatingRequest(BaseModel):
    type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
    message: str
    scope: WsFloatingScope
    format_prefs: FormatPrefsModel | None = None


class WsBriefRequest(BaseModel):
@@ -173,6 +194,7 @@ class WsBriefRequest(BaseModel):
    session_id: str | None = None
    mode: Literal["home", "project"]
    project_id: str | None = None
    format_prefs: FormatPrefsModel | None = None


class WsStreamStart(BaseModel):
@@ -196,6 +218,7 @@ class WsStreamEnd(BaseModel):
    type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
    request_id: str
    error: str | None = None
    mutations: list[dict[str, Any]] | None = None


class WsDomain(BaseModel):


@@ -33,9 +33,11 @@ google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
pgvector>=0.2.5
-langfuse>=2.0.0
+langfuse>=3.3.1
beautifulsoup4>=4.12.0
lxml>=5.0.0
PyYAML>=6.0.0
apscheduler>=3.10.0
ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1


@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select

from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
    return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}


# ── Convenience aliases and per-tier user fixtures ────────────────────


@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
    """Alias for db_session — used by folder quota tests."""
    return db_session


@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
    """Return the seeded free-tier User row."""
    result = await db_session.execute(
        select(User).where(User.id == TEST_USER_IDS["free"])
    )
    return result.scalar_one()


@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
    """Return the seeded power-tier User row."""
    result = await db_session.execute(
        select(User).where(User.id == TEST_USER_IDS["power"])
    )
    return result.scalar_one()


@pytest.fixture
def auth_headers_free() -> dict[str, str]:
    """Authorization header for the seeded free-tier user."""
    return auth_header("free")


# ── CLI options ───────────────────────────────────────────────────────


def pytest_addoption(parser):


@@ -10,8 +10,11 @@ import pytest
from langchain_core.messages import AIMessage, ToolMessage

from app.core.deep_agent import (
    _build_system_prompt,
    _datetime_context_injection,
    _infer_floating_domain,
    _normalize_tagged_list_lines,
    _request_context_block,
    run_floating,
    run_floating_stream,
    run_home,
@@ -91,8 +94,12 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
-assert ("token", "stream-") in events
-assert ("token", "ok") in events
+# _run_single_agent_stream uses ainvoke (not astream); the final token is
+# the second LLM response which echoes the tool result.
+token_events = [e for e in events if e[0] == "token"]
+assert token_events, "Expected at least one token event"
+combined = "".join(str(e[1]) for e in token_events)
+assert "Mock Task" in combined
@pytest.mark.asyncio
@@ -286,3 +293,213 @@ async def test_run_floating_stream_returns_fallback_when_sanitization_would_empt
events.append(event)
assert ("token", "No results found.") in events
# ── _datetime_context_injection ────────────────────────────────────────────────
def _fp(tz: str, now_iso: str) -> dict:
return {"timezone": tz, "now_iso": now_iso, "date_format": "dd/MM/yyyy", "time_format": "24h"}
def _parse_ms(block: str, key: str) -> tuple[int, int]:
"""Extract [start, end] from a 'key [start, end]' line in the DATE CONTEXT block."""
import re
m = re.search(rf"^{key}\s+\[(\d+),\s*(\d+)\]", block, re.MULTILINE)
assert m, f"Key '{key}' not found in block:\n{block}"
return int(m.group(1)), int(m.group(2))
def test_datetime_context_injection_europe_rome_late_evening():
"""22:16 CEST on 2026-04-26 — 'tomorrow' must be 2026-04-27 00:00→23:59:59.999 CEST."""
from zoneinfo import ZoneInfo
from datetime import datetime, timezone
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z")})
assert "DATE CONTEXT" in block
assert "Europe/Rome" in block
tz = ZoneInfo("Europe/Rome")
today_start = int(datetime(2026, 4, 26, tzinfo=tz).timestamp() * 1000)
today_end = int(datetime(2026, 4, 27, tzinfo=tz).timestamp() * 1000) - 1
tomorrow_start = today_end + 1
tomorrow_end = int(datetime(2026, 4, 28, tzinfo=tz).timestamp() * 1000) - 1
t_s, t_e = _parse_ms(block, "today")
assert t_s == today_start
assert t_e == today_end
tm_s, tm_e = _parse_ms(block, "tomorrow")
assert tm_s == tomorrow_start
assert tm_e == tomorrow_end
# Sanity: window is exactly 86 400 000 ms (1 day, CEST has no DST jump on this date)
assert today_end - today_start + 1 == 86_400_000
assert tomorrow_end - tomorrow_start + 1 == 86_400_000
def test_datetime_context_injection_utc():
"""UTC timezone: boundaries are clean UTC midnights."""
from datetime import datetime, timezone
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-01-15T10:00:00Z")})
t_s, t_e = _parse_ms(block, "today")
expected_start = int(datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() * 1000)
assert t_s == expected_start
assert t_e == expected_start + 86_400_000 - 1
def test_datetime_context_injection_dst_spring_forward():
"""Europe/Rome DST spring-forward 2026-03-29: that day is 23h, not 24h."""
from zoneinfo import ZoneInfo
from datetime import datetime
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-03-29T08:00:00Z")})
tz = ZoneInfo("Europe/Rome")
day_start = int(datetime(2026, 3, 29, tzinfo=tz).timestamp() * 1000)
day_end = int(datetime(2026, 3, 30, tzinfo=tz).timestamp() * 1000) - 1
t_s, t_e = _parse_ms(block, "today")
assert t_s == day_start
assert t_e == day_end
assert t_e - t_s + 1 == 23 * 3_600_000 # 23-hour day
def test_datetime_context_injection_dst_fall_back():
"""Europe/Rome DST fall-back 2026-10-25: that day is 25h."""
from zoneinfo import ZoneInfo
from datetime import datetime
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-10-25T08:00:00Z")})
tz = ZoneInfo("Europe/Rome")
day_start = int(datetime(2026, 10, 25, tzinfo=tz).timestamp() * 1000)
day_end = int(datetime(2026, 10, 26, tzinfo=tz).timestamp() * 1000) - 1
t_s, t_e = _parse_ms(block, "today")
assert t_s == day_start
assert t_e == day_end
assert t_e - t_s + 1 == 25 * 3_600_000 # 25-hour day
def test_datetime_context_injection_year_boundary():
"""Dec 31 → Jan 1: last_year, this_year, next_month cross year boundary correctly."""
from zoneinfo import ZoneInfo
from datetime import datetime
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-12-31T23:00:00Z")})
tz = ZoneInfo("UTC")
yr_s, yr_e = _parse_ms(block, "this_year")
assert yr_s == int(datetime(2026, 1, 1, tzinfo=tz).timestamp() * 1000)
assert yr_e == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000) - 1
ly_s, ly_e = _parse_ms(block, "last_year")
assert ly_s == int(datetime(2025, 1, 1, tzinfo=tz).timestamp() * 1000)
assert ly_e == yr_s - 1
nm_s, _ = _parse_ms(block, "next_month")
assert nm_s == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000)
def test_datetime_context_injection_missing_format_prefs():
assert _datetime_context_injection({}) == ""
assert _datetime_context_injection({"format_prefs": None}) == ""
assert _datetime_context_injection({"format_prefs": "bad"}) == ""
# ── _request_context_block ─────────────────────────────────────────────────────
def test_request_context_block_scope_and_project():
ctx = {"scope": {"type": "task", "id": "t-1"}, "resolved_project_id": "proj-uuid"}
block = _request_context_block(ctx)
assert "scope" in block
assert "resolved_project_id: proj-uuid" in block
def test_request_context_block_empty():
assert _request_context_block({}) == ""
assert _request_context_block({"scope": None}) == ""
# ── _build_system_prompt ───────────────────────────────────────────────────────
def test_build_system_prompt_substitutes_all_slots(monkeypatch):
"""All five slots must appear in the compiled output; no raw placeholder remains."""
# Patch get_prompt_or_fallback to return None prompt_obj so we use fallback .format() path
import app.core.deep_agent as da
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
ctx = {
"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z"),
"core_memory": {"language": "it"},
"relational_memory": ["Alice — client"],
"proactive_hints": ["User prefers morning meetings"],
"scope": {"type": "task"},
"resolved_project_id": "proj-1",
}
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, ctx)
# No unresolved placeholders
assert "{date_context}" not in text
assert "{language_instruction}" not in text
assert "{relational_memory}" not in text
assert "{proactive_hints}" not in text
assert "{request_context}" not in text
# Content was injected
assert "DATE CONTEXT" in text
assert "Italian" in text
assert "Alice" in text
assert "morning meetings" in text
assert "proj-1" in text
def test_build_system_prompt_empty_format_prefs(monkeypatch):
"""Missing format_prefs must not raise — date_context slot renders empty string."""
import app.core.deep_agent as da
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, {})
# Prompt renders without error; date section is empty but structure holds
assert "# Date filtering" in text
assert "{date_context}" not in text
def test_human_message_is_bare_message(monkeypatch):
"""After the refactor HumanMessage content must equal the raw user message exactly."""
import app.core.deep_agent as da
from langchain_core.messages import HumanMessage as LCHumanMessage
captured: list[list] = []
class _CaptureLLM:
def bind_tools(self, _):
return self
async def ainvoke(self, messages):
captured.append(list(messages))
return AIMessage(content="risposta")
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda n, f: (f, None))
monkeypatch.setattr(da, "get_agent_llm", lambda _: _CaptureLLM())
monkeypatch.setattr(da, "_all_tools_for_user", lambda *_: [])
monkeypatch.setattr(da, "get_langfuse", lambda: None)
monkeypatch.setattr(da, "set_tool_result_collector", lambda _: None)
monkeypatch.setattr(da, "clear_tool_result_collector", lambda: None)
import asyncio
async def _run():
chunks = []
ctx = {"format_prefs": _fp("UTC", "2026-04-27T10:00:00Z")}
async for ev in da.run_home_stream("u1", "Cosa devo fare domani?", ctx):
chunks.append(ev)
asyncio.run(_run())
assert captured, "LLM was never called"
messages = captured[0]
human = next(m for m in messages if isinstance(m, LCHumanMessage))
assert human.content == "Cosa devo fare domani?"
assert "Context:" not in human.content
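The `_datetime_context_injection` tests above pin down local-day windows in epoch milliseconds, including 23-hour and 25-hour DST days. The implementation is not shown in this diff, so the following is only a sketch of one way to get those boundaries: build both midnights in the target timezone instead of adding 24 hours. `local_day_bounds_ms` and its parameters are hypothetical names.

```python
from datetime import datetime, timedelta, timezone, tzinfo
from typing import Tuple


def local_day_bounds_ms(now_iso: str, tz: tzinfo, day_offset: int = 0) -> Tuple[int, int]:
    """Return [start, end] in epoch ms for the local day containing *now_iso*.

    *day_offset* shifts the window: 0 is today, 1 is tomorrow, -1 is yesterday.
    """
    # Accept a trailing "Z" on Python versions where fromisoformat rejects it.
    now_utc = datetime.fromisoformat(now_iso.replace("Z", "+00:00"))
    local_date = now_utc.astimezone(tz).date() + timedelta(days=day_offset)
    next_date = local_date + timedelta(days=1)
    # Constructing each midnight with tzinfo lets the timezone pick the right
    # UTC offset per day, so DST days come out shorter or longer than 24 hours.
    start = datetime(local_date.year, local_date.month, local_date.day, tzinfo=tz)
    end = datetime(next_date.year, next_date.month, next_date.day, tzinfo=tz)
    return int(start.timestamp() * 1000), int(end.timestamp() * 1000) - 1
```

Passing `zoneinfo.ZoneInfo("Europe/Rome")` as `tz` should give a 23-hour window for 2026-03-29 and a 25-hour window for 2026-10-25, which is what the DST tests above assert. The block itself takes any `tzinfo` so it can run without a timezone database installed.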


@@ -0,0 +1,139 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.folder_agent import (
read_project_folder_file,
search_project_folder_file,
)
pytestmark = pytest.mark.asyncio
async def test_happy_path():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
assert "file body" in out
assert "kind=text" in out
async def test_traversal_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
assert out == "Access denied"
async def test_absolute_path_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
assert out == "Access denied"
async def test_missing_file():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
assert "not found" in out.lower()
async def test_pagination_signals_more_available():
# Electron returned the first slice, totalSize larger than slice length.
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "big.txt",
"offset": 0,
"length": 11,
})
assert "first chunk" in out
assert "More content available" in out
assert "offset=11" in out
async def test_pdf_extracted_then_sliced(monkeypatch):
from app.agents import folder_agent
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "doc.pdf",
"offset": 0,
"length": 8,
})
assert "kind=pdf" in out
assert "ABC ABC " in out
assert "More content available" in out
async def test_image_returns_placeholder():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
assert "image" in out.lower()
async def test_search_finds_match_with_context():
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "log.txt",
"query": "needle",
"context_lines": 1,
})
assert "needle" in out
assert "matches=1" in out
# Context lines included
assert "beta" in out
assert "gamma" in out
async def test_search_no_match():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "x.txt",
"query": "zzz",
})
assert "No matches" in out
async def test_search_rejects_traversal():
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "../etc/passwd",
"query": "root",
})
assert out == "Access denied"
async def test_search_image_rejected():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "logo.png",
"query": "anything",
})
assert "Cannot search" in out
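The traversal and absolute-path tests above expect "Access denied" before any call reaches the client. The real check inside `folder_agent` is not part of this diff; the sketch below shows one possible validation that would satisfy those three cases. `is_safe_relative_path` is a hypothetical helper name.

```python
import re

# Windows drive prefix such as "C:" at the start of the path.
_DRIVE_PREFIX_RE = re.compile(r"^[A-Za-z]:")


def is_safe_relative_path(relative_path: str) -> bool:
    """Return True only for non-empty relative paths with no '..' segment."""
    if not relative_path:
        return False
    # Reject POSIX-absolute, UNC/rooted Windows, and drive-letter paths.
    if relative_path.startswith(("/", "\\")) or _DRIVE_PREFIX_RE.match(relative_path):
        return False
    # Split on both separators so "..\\secret" is caught as well as "../secret".
    segments = re.split(r"[\\/]+", relative_path)
    return ".." not in segments
```

A string check like this is only a first line of defence. The Electron side that actually opens the file would still need to resolve the path and confirm it stays inside the linked folder, since symlinks are invisible to the backend.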


@@ -0,0 +1,83 @@
"""Folder indexer LLM helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
pytestmark = pytest.mark.asyncio
async def test_summarize_text_returns_summary_and_tokens():
mock_resp = AsyncMock()
mock_resp.content = "Kickoff notes covering scope and deadlines."
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
assert isinstance(result, IndexResult)
assert result.summary == "Kickoff notes covering scope and deadlines."
assert result.tokens_used == 338
async def test_summarize_text_truncates_summary_at_500_chars():
mock_resp = AsyncMock()
mock_resp.content = "x" * 1000
mock_resp.usage_metadata = {"total_tokens": 100}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="x", ext=".md", name="x.md")
assert len(result.summary) <= 500
async def test_summarize_image_uses_vision_content_blocks():
mock_resp = AsyncMock()
mock_resp.content = "Final logo on white background."
mock_resp.usage_metadata = {"total_tokens": 500}
captured = {}
async def fake_llm_vision(messages):
captured["messages"] = messages
return mock_resp
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
assert "Final logo" in result.summary
assert result.tokens_used == 500
# last message contains an image content block
last = captured["messages"][-1]
assert any(
isinstance(p, dict) and p.get("type") == "image_url"
for p in (last.content if isinstance(last.content, list) else [])
)
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
# pypdf.PdfReader returns text from pages
from app.core import folder_indexer
class FakePage:
def extract_text(self): return "PDF page content with project info."
class FakeReader:
pages = [FakePage(), FakePage()]
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
assert "Project info" in result.summary
assert result.tokens_used == 50
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
from app.core import folder_indexer
class FakePara:
def __init__(self, t): self.text = t
class FakeDoc:
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
assert result.summary == "Heading and body."
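The indexer tests above fix two behaviours: the summary is capped at 500 characters and `tokens_used` comes from the response's `usage_metadata`. A sketch of a post-processing step consistent with those assertions follows. The `IndexResult` shape is inferred from the tests, and `to_index_result` plus the input/output fallback are assumptions, not code from `folder_indexer`.

```python
from dataclasses import dataclass
from typing import Any, Optional

MAX_SUMMARY_CHARS = 500  # cap asserted by the truncation test


@dataclass
class IndexResult:
    """Shape inferred from the tests: a summary plus the tokens it cost."""
    summary: str
    tokens_used: int


def to_index_result(content: Any, usage_metadata: Optional[dict]) -> IndexResult:
    """Normalise an LLM response into an IndexResult."""
    text = content if isinstance(content, str) else ""
    usage = usage_metadata or {}
    tokens = usage.get("total_tokens")
    if tokens is None:
        # Assumed fallback when the provider reports only the two parts.
        tokens = usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
    return IndexResult(summary=text.strip()[:MAX_SUMMARY_CHARS], tokens_used=int(tokens))
```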


@@ -0,0 +1,94 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
)
assert exc.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
# No raise
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
)
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db.add(MonthlyTokenUsage(
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
))
await db.commit()
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
)
assert exc.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
await check_folder_quota(
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
)
async def test_add_token_usage_atomic_increment(db, test_user_free):
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
ym = datetime.now(timezone.utc).strftime("%Y-%m")
row = (await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == test_user_free.id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)).scalar_one()
assert row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
result = await add_token_usage(
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
)
assert result.exhausted is True
assert result.tokens_used == 150_000
def test_quota_check_endpoint_rejects(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 500},
headers=auth_headers_free,
)
assert res.status_code == 402
body = res.json()
assert body["detail"]["reason"] == "max_files"
def test_quota_check_endpoint_passes(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 50},
headers=auth_headers_free,
)
assert res.status_code == 200
assert res.json() == {"ok": True}
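`test_add_token_usage_atomic_increment` expects two calls to accumulate into a single `monthly_token_usage` row. One common way to get that behaviour is an upsert that increments in place. The sketch below demonstrates the idea with the standard-library `sqlite3` module rather than the project's SQLAlchemy session, so it is an illustration of the technique, not the code in `app.billing.quota`. `make_db`, `add_usage`, and `UsageResult` are hypothetical names, and treating `total >= cap` as exhausted is an assumption. It needs SQLite 3.24 or newer for `ON CONFLICT ... DO UPDATE`.

```python
import sqlite3
from dataclasses import dataclass
from typing import Optional


@dataclass
class UsageResult:
    tokens_used: int
    exhausted: bool


def make_db() -> sqlite3.Connection:
    """In-memory table with the same composite key as MonthlyTokenUsage."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE monthly_token_usage (
            user_id TEXT NOT NULL,
            year_month TEXT NOT NULL,
            feature TEXT NOT NULL,
            tokens_used INTEGER NOT NULL DEFAULT 0,
            PRIMARY KEY (user_id, year_month, feature)
        )
        """
    )
    return conn


def add_usage(conn: sqlite3.Connection, user_id: str, year_month: str,
              feature: str, tokens: int, cap: Optional[int] = None) -> UsageResult:
    """Insert the row or add *tokens* to it, then report the running total."""
    key = (user_id, year_month, feature)
    with conn:  # one transaction: upsert plus read-back
        conn.execute(
            """
            INSERT INTO monthly_token_usage (user_id, year_month, feature, tokens_used)
            VALUES (?, ?, ?, ?)
            ON CONFLICT (user_id, year_month, feature)
            DO UPDATE SET tokens_used = tokens_used + excluded.tokens_used
            """,
            key + (tokens,),
        )
        (total,) = conn.execute(
            "SELECT tokens_used FROM monthly_token_usage "
            "WHERE user_id = ? AND year_month = ? AND feature = ?",
            key,
        ).fetchone()
    return UsageResult(tokens_used=total, exhausted=cap is not None and total >= cap)
```

Doing the addition inside the database avoids the read-modify-write race that a select followed by an update in application code would have when two index batches finish at the same time.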


@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
pytestmark = pytest.mark.asyncio
def test_format_folder_manifest_basic():
manifest = {
"folderPath": "D:\\Acme",
"lastScannedAt": "2h ago",
"files": [
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
],
}
out = format_folder_manifest(manifest)
assert "<linked_folder>" in out
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
assert "[text]" in out
assert "[image]" in out
def test_format_folder_manifest_truncates_past_budget():
files = [
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
for i in range(2000)
]
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
assert "more files omitted" in out
# Rough token check
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
def test_format_folder_manifest_null_returns_empty():
assert format_folder_manifest(None) == ""
assert format_folder_manifest({"files": []}) == ""


async def test_brief_multi_project_manifest_top_5_per_project():
    fake_response = [
        {
            "projectId": "p1", "projectName": "Acme", "folderPath": "/a",
            "lastScannedAt": "now",
            "files": [
                {"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
                for i in range(10)
            ],
        },
        {
            "projectId": "p2", "projectName": "Beta", "folderPath": "/b",
            "lastScannedAt": "now",
            "files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
        },
    ]
    with patch(
        "app.core.deep_agent.execute_on_client",
        new=AsyncMock(return_value={"projects": fake_response}),
    ):
        from app.core.deep_agent import build_brief_multi_project_manifest

        out = await build_brief_multi_project_manifest()
    # Project 1 has 10 files; only the top 5 by mtimeMs should appear.
    assert out.count("[p1]") <= 5
    # Project 2 has 1 file, which must appear.
    assert "[p2]" in out or "Beta" in out
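
Review note: the top-5 assertion above only bounds how many `[p1]` entries appear; it does not check which files were kept. The block below is a minimal sketch of the selection the test presumably exercises, assuming files are ranked by `mtimeMs` in descending order. `top_recent_files` is a hypothetical helper for illustration, not the function in `deep_agent.py`, and the default of 0 for a missing `mtimeMs` is an assumption.

```python
from __future__ import annotations

from typing import Any


def top_recent_files(files: list[dict[str, Any]], limit: int = 5) -> list[dict[str, Any]]:
    """Hypothetical helper: keep the `limit` most recently modified files."""
    # Assumption: entries without mtimeMs are treated as oldest (0).
    ranked = sorted(files, key=lambda f: f.get("mtimeMs", 0), reverse=True)
    return ranked[:limit]


# Same shape as the fake manifest entries used in the test above.
files = [{"relPath": f"f{i}.md", "mtimeMs": i} for i in range(10)]
kept = top_recent_files(files)
kept_paths = [f["relPath"] for f in kept]
```

A stricter version of the test could assert on the exact paths that survive (here `f9.md` down to `f5.md`) instead of only counting entries.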

@@ -0,0 +1,196 @@
"""Tests for WS folder index_session handlers (Task 9).
Tests the three handler functions directly with a minimal fake WebSocket so
no real WS connection or LLM call is made.
"""

from __future__ import annotations

import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch

import pytest
import pytest_asyncio

from app.api.routes.device_ws import (
    _handle_index_session_start,
    _handle_index_file_batch,
    _handle_index_session_cancel,
    _index_sessions,
)
from app.billing.quota import add_token_usage
from app.core.folder_indexer import IndexResult
from app.models import MonthlyTokenUsage
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS

pytestmark = pytest.mark.asyncio

USER_ID = TEST_USER_IDS["free"]
POWER_USER_ID = TEST_USER_IDS["power"]


# ── Fake WebSocket ────────────────────────────────────────────────────


class _FakeWebSocket:
    """Minimal WebSocket stand-in that records send_text calls."""

    def __init__(self) -> None:
        self.sent: list[dict] = []

    async def send_text(self, text: str) -> None:
        self.sent.append(json.loads(text))

    def sent_types(self) -> list[str]:
        return [f["type"] for f in self.sent]


# ── Helpers ───────────────────────────────────────────────────────────


def _make_session_id() -> str:
    import uuid

    return str(uuid.uuid4())


def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
    """Return an async function that resolves to a fixed IndexResult.

    Intended for use as a ``side_effect`` when patching ``summarize_text``.
    """

    async def _fake(**kwargs) -> IndexResult:
        return IndexResult(summary=summary, tokens_used=tokens)

    return _fake


# ── Fixtures ──────────────────────────────────────────────────────────


@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
    """Ensure _index_sessions is empty before and after each test."""
    _index_sessions.clear()
    yield
    _index_sessions.clear()


# ── Tests ─────────────────────────────────────────────────────────────


async def test_index_session_happy_path(db_session):
    """start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Register session.
    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "projectId": "proj-1",
        "totalFiles": 2,
    })

    # Verify session was registered.
    assert session_id in _index_sessions
    assert _index_sessions[session_id]["total"] == 2
    assert _index_sessions[session_id]["processed"] == 0
    # No response frames expected for session_start.
    assert ws.sent == []

    # Send a batch of 2 text files; patch summarize_text in folder_indexer
    # so that no LLM call is needed.
    with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
        with patch("app.api.routes.device_ws.async_session") as mock_async_session:
            # Wire db_session into the context manager.
            mock_cm = AsyncMock()
            mock_cm.__aenter__ = AsyncMock(return_value=db_session)
            mock_cm.__aexit__ = AsyncMock(return_value=False)
            mock_async_session.return_value = mock_cm

            await _handle_index_file_batch(ws, USER_ID, {
                "sessionId": session_id,
                "files": [
                    {"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
                    {"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
                ],
            })

    types = ws.sent_types()
    # Expect 2 file results + 1 progress + 1 done(completed).
    assert types.count(WsFrameType.index_file_result) == 2
    assert types.count(WsFrameType.index_session_progress) == 1
    assert types.count(WsFrameType.index_session_done) == 1

    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "completed"

    progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
    assert progress_frame["processed"] == 2
    assert progress_frame["total"] == 2

    # Verify session cleaned up.
    assert session_id not in _index_sessions


async def test_index_session_cancel(db_session):
    """start then cancel → index_session_done(cancelled)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "totalFiles": 5,
    })
    assert session_id in _index_sessions

    await _handle_index_session_cancel(ws, {"sessionId": session_id})

    types = ws.sent_types()
    assert WsFrameType.index_session_done in types
    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "cancelled"

    # Session should be cleaned up.
    assert session_id not in _index_sessions


async def test_index_session_quota_exceeded(db_session):
    """Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Pre-fill monthly token usage to the free-tier cap (100_000).
    ym = datetime.now(timezone.utc).strftime("%Y-%m")
    db_session.add(MonthlyTokenUsage(
        user_id=USER_ID,
        year_month=ym,
        feature="folder_index",
        tokens_used=100_000,  # free tier cap exactly
    ))
    await db_session.commit()

    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "totalFiles": 1,
    })

    with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
        with patch("app.api.routes.device_ws.async_session") as mock_async_session:
            mock_cm = AsyncMock()
            mock_cm.__aenter__ = AsyncMock(return_value=db_session)
            mock_cm.__aexit__ = AsyncMock(return_value=False)
            mock_async_session.return_value = mock_cm

            await _handle_index_file_batch(ws, USER_ID, {
                "sessionId": session_id,
                "files": [
                    {"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
                ],
            })

    types = ws.sent_types()
    # Should have 1 file result (success) then done(quota_exceeded).
    assert WsFrameType.index_file_result in types
    assert WsFrameType.index_session_done in types
    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "quota_exceeded"

    # Session should be cleaned up.
    assert session_id not in _index_sessions
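
Review note: two tests above repeat the same four lines to make the patched `async_session` yield `db_session` from `async with async_session() as s`. The block below is a minimal sketch of how that wiring could be factored into one helper. `make_session_factory` is a hypothetical name that does not exist in the codebase, and the sketch assumes `async_session` is only ever used as `async with async_session() as session`.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock


def make_session_factory(session: object) -> MagicMock:
    """Hypothetical helper: stand-in for `async_session`.

    Calling the returned factory gives an async context manager whose
    `__aenter__` resolves to `session`.
    """
    cm = MagicMock()
    cm.__aenter__ = AsyncMock(return_value=session)
    # Returning False from __aexit__ lets exceptions propagate.
    cm.__aexit__ = AsyncMock(return_value=False)
    return MagicMock(return_value=cm)


async def _demo() -> bool:
    sentinel = object()
    factory = make_session_factory(sentinel)
    async with factory() as s:
        # The context manager hands back exactly the object we wired in.
        return s is sentinel


demo_result = asyncio.run(_demo())
```

With such a helper, a test could use `patch("app.api.routes.device_ws.async_session", new=make_session_factory(db_session))` in place of the manual `mock_cm` setup.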