Compare commits
28 Commits
0b5ef48463
...
feat/proje
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc0e258e8c | ||
|
|
12e203e63d | ||
|
|
ffcd7390f0 | ||
|
|
91e880f9d4 | ||
|
|
7d47ca54be | ||
|
|
956fa88853 | ||
|
|
fb2f59ccea | ||
|
|
56dbb7f4cd | ||
|
|
506f517851 | ||
|
|
520c186991 | ||
|
|
582bf27deb | ||
|
|
2aeb453229 | ||
|
|
b7a4edac90 | ||
|
|
822b4cd8b1 | ||
|
|
ab24fc4c91 | ||
|
|
a98e99f7a2 | ||
|
|
a0ff285bcd | ||
|
|
177c1a87dd | ||
|
|
441a4ea05c | ||
|
|
a693a64bf5 | ||
|
|
67562b8092 | ||
|
|
6f4c68b359 | ||
|
|
c20c6d7853 | ||
|
|
6787e690ba | ||
|
|
cb8f56d909 | ||
|
|
2c7cac9e03 | ||
|
|
ea9094f47f | ||
|
|
d5fea95561 |
10
.env.example
10
.env.example
@@ -21,6 +21,8 @@ OPENAI_API_KEY=
|
|||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
CEREBRAS_API_KEY=
|
CEREBRAS_API_KEY=
|
||||||
|
GROQ_API_KEY=
|
||||||
|
DEEPSEEK_API_KEY=
|
||||||
|
|
||||||
# Default model used by any agent that does not have a specific override below.
|
# Default model used by any agent that does not have a specific override below.
|
||||||
LLM_MODEL=gpt-5-mini
|
LLM_MODEL=gpt-5-mini
|
||||||
@@ -50,6 +52,14 @@ LLM_MODEL_UNIFIED_PROCESSOR=
|
|||||||
# Cloud-processor — fetches and processes data from cloud connectors.
|
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||||
LLM_MODEL_CLOUD_PROCESSOR=
|
LLM_MODEL_CLOUD_PROCESSOR=
|
||||||
|
|
||||||
|
# Brief-agent — produces home and project text briefs.
|
||||||
|
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||||
|
# LLM_MODEL_BRIEF_AGENT=
|
||||||
|
|
||||||
|
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
||||||
|
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
||||||
|
# LLM_MODEL_TASK_BRIEF_AGENT=
|
||||||
|
|
||||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||||
LLM_MODEL_SETUP_AGENT=
|
LLM_MODEL_SETUP_AGENT=
|
||||||
|
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -28,6 +28,9 @@ tests/fixtures/private*/
|
|||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# Smoke scripts (dev-only, not for CI)
|
||||||
|
scripts/smoke_*.py
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
|
|
||||||
# Claude Code
|
# Claude Code
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
## DEV
|
||||||
|
Run in DEV with command:
|
||||||
|
```
|
||||||
|
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
|
||||||
|
```
|
||||||
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Add token tracking columns for folder integration.
|
||||||
|
|
||||||
|
Revision ID: d6e3f4a5b6c7
|
||||||
|
Revises: 006
|
||||||
|
Create Date: 2026-05-11 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "d6e3f4a5b6c7"
|
||||||
|
down_revision: Union[str, None] = "006"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"agent_run_logs",
|
||||||
|
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"monthly_token_usage",
|
||||||
|
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||||
|
sa.Column("year_month", sa.String(7), nullable=False),
|
||||||
|
sa.Column("feature", sa.String(64), nullable=False),
|
||||||
|
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
"ix_monthly_token_usage_user_month",
|
||||||
|
"monthly_token_usage",
|
||||||
|
["user_id", "year_month"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
||||||
|
op.drop_table("monthly_token_usage")
|
||||||
|
op.drop_column("agent_run_logs", "tokens_used")
|
||||||
52
app/agents/client_agent.py
Normal file
52
app/agents/client_agent.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""Client agent — read-only tools for the clients table."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_clients(search: str = "", limit: int = 20) -> str:
|
||||||
|
"""List clients, optionally filtered by a name/email substring search.
|
||||||
|
|
||||||
|
search: optional substring to match against client name or email.
|
||||||
|
limit: max rows to return (default 20).
|
||||||
|
"""
|
||||||
|
filters: dict[str, Any] = {"limit": limit}
|
||||||
|
if search:
|
||||||
|
filters["search"] = search
|
||||||
|
|
||||||
|
result = await execute_on_client(action="select", table="clients", filters=filters)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No clients found."
|
||||||
|
lines = [
|
||||||
|
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
||||||
|
f"company: {r.get('company', '')})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_client(id: str) -> str:
|
||||||
|
"""Get full details for one client by UUID.
|
||||||
|
|
||||||
|
id: the client's UUID.
|
||||||
|
"""
|
||||||
|
if not id:
|
||||||
|
return "Client id is required."
|
||||||
|
|
||||||
|
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
||||||
|
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
||||||
|
if not row:
|
||||||
|
return f"Client '{id}' not found."
|
||||||
|
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
||||||
|
|
||||||
|
|
||||||
|
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
||||||
168
app/agents/folder_agent.py
Normal file
168
app/agents/folder_agent.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Scoped file-read and search tools for the project folder feature."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
# Cap returned slice size to keep tool output under control.
|
||||||
|
_MAX_RETURN_CHARS = 50_000
|
||||||
|
_MAX_SEARCH_MATCHES = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _is_unsafe_path(rel: str) -> bool:
|
||||||
|
if not rel:
|
||||||
|
return True
|
||||||
|
norm = rel.replace("\\", "/")
|
||||||
|
if norm.startswith("/"):
|
||||||
|
return True
|
||||||
|
# Windows drive letter
|
||||||
|
if len(rel) >= 2 and rel[1] == ":":
|
||||||
|
return True
|
||||||
|
parts = norm.split("/")
|
||||||
|
return ".." in parts
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
||||||
|
"""Return the raw Electron tool_result dict for a file read."""
|
||||||
|
return await execute_on_client(
|
||||||
|
action="read_project_folder_file",
|
||||||
|
data={
|
||||||
|
"projectId": project_id,
|
||||||
|
"relativePath": relative_path,
|
||||||
|
"offset": offset,
|
||||||
|
"length": length,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode(result: dict) -> tuple[str, str, int]:
|
||||||
|
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
||||||
|
extracts text from base64. For images, returns a placeholder string.
|
||||||
|
For text, content is already a sliced utf-8 string.
|
||||||
|
"""
|
||||||
|
kind = result.get("kind", "text")
|
||||||
|
content = result.get("content", "") or ""
|
||||||
|
total = int(result.get("totalSize", 0) or 0)
|
||||||
|
if kind == "image":
|
||||||
|
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
||||||
|
if kind == "pdf":
|
||||||
|
return (_extract_pdf_text(content), kind, total)
|
||||||
|
if kind == "docx":
|
||||||
|
return (_extract_docx_text(content), kind, total)
|
||||||
|
return (content, kind, total)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_project_folder_file(
|
||||||
|
project_id: str,
|
||||||
|
relative_path: str,
|
||||||
|
offset: int = 0,
|
||||||
|
length: int = _MAX_RETURN_CHARS,
|
||||||
|
) -> str:
|
||||||
|
"""Read a slice of a file inside the project's linked folder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: project ID.
|
||||||
|
relative_path: path relative to the linked folder root.
|
||||||
|
offset: char offset to start reading from (0 = beginning).
|
||||||
|
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
||||||
|
|
||||||
|
Returns text content slice with a header showing position. Header tells you
|
||||||
|
when more content is available; call again with the suggested next offset.
|
||||||
|
|
||||||
|
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
||||||
|
on the extracted text. For images returns a placeholder; navigate with the
|
||||||
|
manifest summary instead.
|
||||||
|
"""
|
||||||
|
if _is_unsafe_path(relative_path):
|
||||||
|
return "Access denied"
|
||||||
|
|
||||||
|
result = await _fetch_file(project_id, relative_path, offset, length)
|
||||||
|
text, kind, total_size = _decode(result)
|
||||||
|
|
||||||
|
if not text and kind in ("missing", "error"):
|
||||||
|
return f"File not found or unreadable: {relative_path}"
|
||||||
|
|
||||||
|
if kind in ("pdf", "docx"):
|
||||||
|
# Backend extracted full text — apply offset/length on chars.
|
||||||
|
sliced = text[offset:offset + length]
|
||||||
|
slice_end = min(offset + length, len(text))
|
||||||
|
header = (
|
||||||
|
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
||||||
|
f"totalChars={len(text)}]"
|
||||||
|
)
|
||||||
|
if slice_end < len(text):
|
||||||
|
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||||
|
return header + "\n" + sliced
|
||||||
|
|
||||||
|
if kind == "text":
|
||||||
|
slice_end = offset + len(text)
|
||||||
|
header = (
|
||||||
|
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
||||||
|
f"totalBytes={total_size}]"
|
||||||
|
)
|
||||||
|
if slice_end < total_size:
|
||||||
|
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||||
|
return header + "\n" + text
|
||||||
|
|
||||||
|
# image or unknown
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def search_project_folder_file(
|
||||||
|
project_id: str,
|
||||||
|
relative_path: str,
|
||||||
|
query: str,
|
||||||
|
context_lines: int = 3,
|
||||||
|
) -> str:
|
||||||
|
"""Search a project folder file for a query string (case-insensitive substring).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: project ID.
|
||||||
|
relative_path: path relative to the linked folder root.
|
||||||
|
query: text to search for.
|
||||||
|
context_lines: number of lines of context around each match (default 3).
|
||||||
|
|
||||||
|
Returns matching line ranges with surrounding context and 1-based line numbers.
|
||||||
|
Capped at 20 matches; if more exist the header shows the total.
|
||||||
|
|
||||||
|
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
||||||
|
Images and binary files are not searchable.
|
||||||
|
"""
|
||||||
|
if _is_unsafe_path(relative_path):
|
||||||
|
return "Access denied"
|
||||||
|
if not query:
|
||||||
|
return "Empty query."
|
||||||
|
|
||||||
|
# For text we still need full file; pass length=very large.
|
||||||
|
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
||||||
|
text, kind, _ = _decode(result)
|
||||||
|
|
||||||
|
if not text and kind in ("missing", "error"):
|
||||||
|
return f"File not found or unreadable: {relative_path}"
|
||||||
|
if kind == "image":
|
||||||
|
return "Cannot search inside images."
|
||||||
|
|
||||||
|
lines = text.splitlines()
|
||||||
|
q = query.lower()
|
||||||
|
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
||||||
|
if not matches:
|
||||||
|
return f"No matches for '{query}' in {relative_path}."
|
||||||
|
|
||||||
|
shown = matches[:_MAX_SEARCH_MATCHES]
|
||||||
|
snippets: list[str] = []
|
||||||
|
for i in shown:
|
||||||
|
start = max(0, i - context_lines)
|
||||||
|
end = min(len(lines), i + context_lines + 1)
|
||||||
|
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
||||||
|
snippets.append(block)
|
||||||
|
|
||||||
|
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
||||||
|
body = "\n---\n".join(snippets)
|
||||||
|
return header + "\n" + body
|
||||||
|
|
||||||
|
|
||||||
|
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
||||||
@@ -1,13 +1,14 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.llm import embed
|
from app.core.note_summarizer import generate_note_summary
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
_UUID_RE = re.compile(
|
||||||
@@ -19,9 +20,21 @@ def _is_uuid(value: str) -> bool:
|
|||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
|
||||||
|
def _fmt_summary(row: dict) -> str:
|
||||||
|
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
|
||||||
|
if summary:
|
||||||
|
return f" — {summary}"
|
||||||
|
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
|
||||||
|
return f" — {snippet}" if snippet else ""
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes with AI summaries, optionally scoped to a project by project_id.
|
||||||
|
|
||||||
|
Returns id, title, and ai_summary for each note so you can decide which
|
||||||
|
note to read in full with get_note before creating or updating.
|
||||||
|
"""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
@@ -31,7 +44,7 @@ async def list_notes(project_id: str = "") -> str:
|
|||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No notes found."
|
return "No notes found."
|
||||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
|
||||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@@ -66,14 +79,10 @@ async def create_note(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
# Index the note content in the vector store.
|
note_id: str = row["id"]
|
||||||
vector = await embed(content)
|
# Generate summary asynchronously — fire-and-forget.
|
||||||
await execute_on_client(
|
asyncio.create_task(_refresh_summary(note_id, title, content))
|
||||||
action="vector_upsert",
|
return f"Note created: '{row['title']}' (id: {note_id})."
|
||||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -82,7 +91,8 @@ async def update_note(
|
|||||||
title: str = "",
|
title: str = "",
|
||||||
content: str = "",
|
content: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update an existing note. Only pass fields that should change.
|
"""Update an existing note directly (no approval required).
|
||||||
|
Use propose_note_edit instead when human review is needed.
|
||||||
note_id: UUID of the note (required)
|
note_id: UUID of the note (required)
|
||||||
If you need to preserve existing content, call get_note first.
|
If you need to preserve existing content, call get_note first.
|
||||||
"""
|
"""
|
||||||
@@ -97,17 +107,63 @@ async def update_note(
|
|||||||
data={"id": note_id, "updates": updates},
|
data={"id": note_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
# Re-index if content changed.
|
|
||||||
if content:
|
if content:
|
||||||
vector = await embed(content)
|
new_title = title or row.get("title", "")
|
||||||
await execute_on_client(
|
asyncio.create_task(_refresh_summary(note_id, new_title, content))
|
||||||
action="vector_upsert",
|
|
||||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def propose_note_edit(
|
||||||
|
note_id: str,
|
||||||
|
edit_type: str,
|
||||||
|
proposed_content: str,
|
||||||
|
reasoning: str = "",
|
||||||
|
anchor_before: str = "",
|
||||||
|
anchor_text: str = "",
|
||||||
|
agent_id: str = "",
|
||||||
|
run_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Propose an AI edit to an existing note, pending human approval.
|
||||||
|
|
||||||
|
Use this instead of update_note when review_required is true.
|
||||||
|
The user will see the proposal highlighted before it is merged.
|
||||||
|
|
||||||
|
note_id: UUID of the target note (required)
|
||||||
|
edit_type: 'append' | 'insert' | 'replace'
|
||||||
|
- append: adds proposed_content at the end of the note
|
||||||
|
- insert: inserts proposed_content immediately after anchor_before text
|
||||||
|
- replace: replaces the first occurrence of anchor_text with proposed_content
|
||||||
|
proposed_content: the new Markdown text to add or substitute (required)
|
||||||
|
reasoning: brief explanation shown to the user (recommended)
|
||||||
|
anchor_before: for 'insert' — the text snippet that precedes the insertion point
|
||||||
|
anchor_text: for 'replace' — the exact text to be replaced
|
||||||
|
agent_id: agent identifier (for traceability)
|
||||||
|
run_id: run identifier (for traceability)
|
||||||
|
"""
|
||||||
|
if edit_type not in ("append", "insert", "replace"):
|
||||||
|
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
|
||||||
|
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="propose_note_edit",
|
||||||
|
data={
|
||||||
|
"noteId": note_id,
|
||||||
|
"type": edit_type,
|
||||||
|
"proposedContent": proposed_content,
|
||||||
|
"reasoning": reasoning or None,
|
||||||
|
"anchorBefore": anchor_before or None,
|
||||||
|
"anchorText": anchor_text or None,
|
||||||
|
"agentId": agent_id or None,
|
||||||
|
"runId": run_id or None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
edit_id = result.get("id", "?")
|
||||||
|
return (
|
||||||
|
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
|
||||||
|
f"Status: pending user approval."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_note(note_id: str) -> str:
|
async def delete_note(note_id: str) -> str:
|
||||||
"""Delete a note permanently by its UUID."""
|
"""Delete a note permanently by its UUID."""
|
||||||
@@ -115,10 +171,36 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
|
||||||
|
"""Generate and persist the AI summary for a note. Fire-and-forget."""
|
||||||
|
try:
|
||||||
|
summary = await generate_note_summary(title, content)
|
||||||
|
if summary:
|
||||||
|
await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="notes",
|
||||||
|
data={
|
||||||
|
"id": note_id,
|
||||||
|
"updates": {
|
||||||
|
"aiSummary": summary,
|
||||||
|
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # fire-and-forget; errors logged by generate_note_summary
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
NOTE_TOOLS: list[Any] = [
|
||||||
list_notes,
|
list_notes,
|
||||||
get_note,
|
get_note,
|
||||||
create_note,
|
create_note,
|
||||||
update_note,
|
update_note,
|
||||||
|
propose_note_edit,
|
||||||
delete_note,
|
delete_note,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
NOTE_READ_TOOLS: list[Any] = [
|
||||||
|
list_notes,
|
||||||
|
get_note,
|
||||||
|
]
|
||||||
|
|||||||
@@ -125,3 +125,9 @@ PROJECT_TOOLS: list[Any] = [
|
|||||||
update_project,
|
update_project,
|
||||||
delete_project,
|
delete_project,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PROJECT_READ_TOOLS: list[Any] = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
]
|
||||||
|
|||||||
63
app/agents/relations_agent.py
Normal file
63
app/agents/relations_agent.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import async_session
|
||||||
|
|
||||||
|
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
||||||
|
# Each tool closure captures the user_id bound at factory time.
|
||||||
|
|
||||||
|
|
||||||
|
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
||||||
|
"""Return a query_relations tool bound to *user_id*."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def query_relations(
|
||||||
|
subject_label: str = "",
|
||||||
|
predicate: str = "",
|
||||||
|
object_label: str = "",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> str:
|
||||||
|
"""Query the relational memory graph for entity relationships.
|
||||||
|
|
||||||
|
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
||||||
|
All parameters are optional — omit to retrieve all relations up to limit.
|
||||||
|
|
||||||
|
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
||||||
|
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
||||||
|
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
||||||
|
limit: max rows to return (default 10).
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.info(
|
||||||
|
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
||||||
|
trace_id or "-", user_id, subject_label, predicate, object_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
rows = await memory.query_relations(
|
||||||
|
user_id=user_id,
|
||||||
|
subject=subject_label or None,
|
||||||
|
predicate=predicate or None,
|
||||||
|
object_=object_label or None,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return "No relational memory entries found for the given filters."
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
||||||
|
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
return query_relations
|
||||||
@@ -26,32 +26,141 @@ def _is_uuid(value: str) -> bool:
|
|||||||
async def list_tasks(
|
async def list_tasks(
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
status: str = "",
|
status: str = "",
|
||||||
|
priority: str = "",
|
||||||
|
assignee: str = "",
|
||||||
search: str = "",
|
search: str = "",
|
||||||
order_by: str = "",
|
order_by: str = "",
|
||||||
|
order_dir: str = "",
|
||||||
|
due_date_from: int = -1,
|
||||||
|
due_date_to: int = -1,
|
||||||
|
created_at_from: int = -1,
|
||||||
|
created_at_to: int = -1,
|
||||||
|
completed_at_from: int = -1,
|
||||||
|
completed_at_to: int = -1,
|
||||||
|
is_ai_suggested: int = -1,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks with optional filters. Returns up to `limit` results (default 50).
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
|
||||||
|
project_id: UUID of the project to scope results to.
|
||||||
|
status: filter by status — todo | in_progress | done.
|
||||||
|
priority: filter by priority — high | medium | low.
|
||||||
|
assignee: substring to match against assignee names. OMIT unless the user explicitly
|
||||||
|
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
|
||||||
|
Do NOT default to the current user.
|
||||||
|
search: substring search across title and description.
|
||||||
|
order_by: sort field — dueDate | priority | createdAt | completedAt.
|
||||||
|
order_dir: asc (default) | desc.
|
||||||
|
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||||
|
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||||
|
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||||
|
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||||
|
limit: max rows to return (default 50). Use with offset to paginate.
|
||||||
|
offset: skip first N rows (default 0).
|
||||||
|
|
||||||
|
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||||
|
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
|
||||||
|
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
|
||||||
|
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||||
|
do not compute boundaries from the current UTC instant.
|
||||||
|
"""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
filters: dict[str, Any] = {
|
||||||
action="select",
|
"projectId": normalized_project_id or None,
|
||||||
table="tasks",
|
"status": status or None,
|
||||||
filters={
|
"priority": priority or None,
|
||||||
"projectId": normalized_project_id or None,
|
"search": search or None,
|
||||||
"status": status or None,
|
"orderBy": order_by or None,
|
||||||
"search": search or None,
|
"orderDir": order_dir or None,
|
||||||
"orderBy": order_by or None,
|
"limit": limit,
|
||||||
},
|
"offset": offset,
|
||||||
)
|
}
|
||||||
|
if assignee:
|
||||||
|
filters["assignee"] = assignee
|
||||||
|
if due_date_from != -1:
|
||||||
|
filters["dueDateFrom"] = due_date_from
|
||||||
|
if due_date_to != -1:
|
||||||
|
filters["dueDateTo"] = due_date_to
|
||||||
|
if created_at_from != -1:
|
||||||
|
filters["createdAtFrom"] = created_at_from
|
||||||
|
if created_at_to != -1:
|
||||||
|
filters["createdAtTo"] = created_at_to
|
||||||
|
if completed_at_from != -1:
|
||||||
|
filters["completedAtFrom"] = completed_at_from
|
||||||
|
if completed_at_to != -1:
|
||||||
|
filters["completedAtTo"] = completed_at_to
|
||||||
|
if is_ai_suggested != -1:
|
||||||
|
filters["isAiSuggested"] = is_ai_suggested
|
||||||
|
|
||||||
|
result = await execute_on_client(action="select", table="tasks", filters=filters)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No tasks found matching the given filters."
|
return "No tasks found matching the given filters."
|
||||||
lines = [
|
lines = [
|
||||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
|
||||||
|
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
|
||||||
|
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def count_tasks(
|
||||||
|
project_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
priority: str = "",
|
||||||
|
assignee: str = "",
|
||||||
|
search: str = "",
|
||||||
|
due_date_from: int = -1,
|
||||||
|
due_date_to: int = -1,
|
||||||
|
created_at_from: int = -1,
|
||||||
|
created_at_to: int = -1,
|
||||||
|
completed_at_from: int = -1,
|
||||||
|
completed_at_to: int = -1,
|
||||||
|
is_ai_suggested: int = -1,
|
||||||
|
) -> str:
|
||||||
|
"""Count tasks matching the given filters without returning rows.
|
||||||
|
|
||||||
|
Use this instead of list_tasks for "how many" questions — it is much cheaper.
|
||||||
|
Same filter parameters as list_tasks (no limit/offset/order_by needed).
|
||||||
|
assignee: OMIT unless the user explicitly names a person or refers to themselves
|
||||||
|
("my tasks"). Do NOT default to the current user.
|
||||||
|
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||||
|
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||||
|
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||||
|
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
|
||||||
|
do not compute boundaries from the current UTC instant.
|
||||||
|
"""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
filters: dict[str, Any] = {
|
||||||
|
"projectId": normalized_project_id or None,
|
||||||
|
"status": status or None,
|
||||||
|
"priority": priority or None,
|
||||||
|
"search": search or None,
|
||||||
|
}
|
||||||
|
if assignee:
|
||||||
|
filters["assignee"] = assignee
|
||||||
|
if due_date_from != -1:
|
||||||
|
filters["dueDateFrom"] = due_date_from
|
||||||
|
if due_date_to != -1:
|
||||||
|
filters["dueDateTo"] = due_date_to
|
||||||
|
if created_at_from != -1:
|
||||||
|
filters["createdAtFrom"] = created_at_from
|
||||||
|
if created_at_to != -1:
|
||||||
|
filters["createdAtTo"] = created_at_to
|
||||||
|
if completed_at_from != -1:
|
||||||
|
filters["completedAtFrom"] = completed_at_from
|
||||||
|
if completed_at_to != -1:
|
||||||
|
filters["completedAtTo"] = completed_at_to
|
||||||
|
if is_ai_suggested != -1:
|
||||||
|
filters["isAiSuggested"] = is_ai_suggested
|
||||||
|
|
||||||
|
result = await execute_on_client(action="count", table="tasks", filters=filters)
|
||||||
|
return f"Task count: {result.get('count', 0)}"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_task(
|
async def create_task(
|
||||||
title: str,
|
title: str,
|
||||||
@@ -72,6 +181,8 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
|
||||||
|
completedAt is set automatically when status is 'done'.
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -90,7 +201,7 @@ async def create_task(
|
|||||||
row = result["row"]
|
row = result["row"]
|
||||||
return (
|
return (
|
||||||
f"Task created: '{row['title']}' "
|
f"Task created: '{row['title']}' "
|
||||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -108,6 +219,10 @@ async def update_task(
|
|||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
|
||||||
|
completedAt is managed automatically:
|
||||||
|
- setting status to 'done' records the current timestamp
|
||||||
|
- changing status away from 'done' clears completedAt
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -130,7 +245,7 @@ async def update_task(
|
|||||||
data={"id": task_id, "updates": updates},
|
data={"id": task_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -141,21 +256,36 @@ async def delete_task(task_id: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_tasks_due_today() -> str:
|
async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
|
||||||
"""List all tasks whose due date falls on today's date."""
|
"""List all tasks whose due date falls on today's date.
|
||||||
now = datetime.now(tz=timezone.utc)
|
|
||||||
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||||
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
Always pass the user's timezone so 'today' is computed in their local time.
|
||||||
|
include_done: set True to also include already-completed tasks due today (default False).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
tz = ZoneInfo(user_timezone or "UTC")
|
||||||
|
except Exception:
|
||||||
|
tz = timezone.utc
|
||||||
|
now_local = datetime.now(tz=tz)
|
||||||
|
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||||
|
start_ms = int(start_dt.timestamp() * 1000)
|
||||||
|
end_ms = start_ms + 86_400_000 - 1
|
||||||
|
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
|
||||||
|
if not include_done:
|
||||||
|
filters["status"] = "todo"
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
filters=filters,
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No tasks are due today."
|
return "No tasks are due today."
|
||||||
lines = [
|
lines = [
|
||||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
|
||||||
|
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
@@ -193,7 +323,6 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
)
|
)
|
||||||
row = result.get("row", {})
|
row = result.get("row", {})
|
||||||
row_author = row.get("author", author)
|
row_author = row.get("author", author)
|
||||||
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
|
||||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
row_comment_id = row.get("id", "unknown")
|
row_comment_id = row.get("id", "unknown")
|
||||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
@@ -211,6 +340,7 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
|
|
||||||
TASK_TOOLS: list[Any] = [
|
TASK_TOOLS: list[Any] = [
|
||||||
list_tasks,
|
list_tasks,
|
||||||
|
count_tasks,
|
||||||
create_task,
|
create_task,
|
||||||
update_task,
|
update_task,
|
||||||
delete_task,
|
delete_task,
|
||||||
@@ -219,3 +349,10 @@ TASK_TOOLS: list[Any] = [
|
|||||||
add_task_comment,
|
add_task_comment,
|
||||||
delete_task_comment,
|
delete_task_comment,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TASK_READ_TOOLS: list[Any] = [
|
||||||
|
list_tasks,
|
||||||
|
count_tasks,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -19,19 +20,128 @@ def _is_uuid(value: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
project_id: str = "",
|
||||||
|
type: str = "",
|
||||||
|
is_completed: int = -1,
|
||||||
|
is_ai_suggested: int = -1,
|
||||||
|
order_by: str = "",
|
||||||
|
order_dir: str = "",
|
||||||
|
date_from: int = -1,
|
||||||
|
date_to: int = -1,
|
||||||
|
created_at_from: int = -1,
|
||||||
|
created_at_to: int = -1,
|
||||||
|
completed_at_from: int = -1,
|
||||||
|
completed_at_to: int = -1,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""List timeline events (milestones, checkpoints, activities) with optional filters.
|
||||||
|
|
||||||
|
project_id: UUID to scope results to a specific project.
|
||||||
|
type: filter by event type — milestone | checkpoint | activity.
|
||||||
|
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
|
||||||
|
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||||
|
order_by: sort field — date (default) | createdAt | completedAt.
|
||||||
|
order_dir: asc (default) | desc.
|
||||||
|
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||||
|
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||||
|
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||||
|
limit: max rows to return (default 50). Use with offset to paginate.
|
||||||
|
offset: skip first N rows (default 0).
|
||||||
|
|
||||||
|
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||||
|
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
|
||||||
|
Tip — for natural-language windows ("today", "this week", "last month", etc.)
|
||||||
|
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||||
|
do not compute boundaries from the current UTC instant.
|
||||||
|
"""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
filters: dict[str, Any] = {
|
||||||
action="select",
|
"projectId": normalized_project_id or None,
|
||||||
table="timelines",
|
"orderBy": order_by or None,
|
||||||
filters={"projectId": normalized_project_id or None},
|
"orderDir": order_dir or None,
|
||||||
)
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
}
|
||||||
|
if type:
|
||||||
|
filters["type"] = type
|
||||||
|
if is_completed != -1:
|
||||||
|
filters["isCompleted"] = is_completed
|
||||||
|
if is_ai_suggested != -1:
|
||||||
|
filters["isAiSuggested"] = is_ai_suggested
|
||||||
|
if date_from != -1:
|
||||||
|
filters["dateFrom"] = date_from
|
||||||
|
if date_to != -1:
|
||||||
|
filters["dateTo"] = date_to
|
||||||
|
if created_at_from != -1:
|
||||||
|
filters["createdAtFrom"] = created_at_from
|
||||||
|
if created_at_to != -1:
|
||||||
|
filters["createdAtTo"] = created_at_to
|
||||||
|
if completed_at_from != -1:
|
||||||
|
filters["completedAtFrom"] = completed_at_from
|
||||||
|
if completed_at_to != -1:
|
||||||
|
filters["completedAtTo"] = completed_at_to
|
||||||
|
|
||||||
|
result = await execute_on_client(action="select", table="timelines", filters=filters)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No timelines found."
|
return "No timeline events found."
|
||||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
lines = [
|
||||||
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||||
|
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
|
||||||
|
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def count_timelines(
|
||||||
|
project_id: str = "",
|
||||||
|
type: str = "",
|
||||||
|
is_completed: int = -1,
|
||||||
|
is_ai_suggested: int = -1,
|
||||||
|
date_from: int = -1,
|
||||||
|
date_to: int = -1,
|
||||||
|
created_at_from: int = -1,
|
||||||
|
created_at_to: int = -1,
|
||||||
|
completed_at_from: int = -1,
|
||||||
|
completed_at_to: int = -1,
|
||||||
|
) -> str:
|
||||||
|
"""Count timeline events matching the given filters without returning rows.
|
||||||
|
|
||||||
|
Use this instead of list_timelines for "how many" questions — it is much cheaper.
|
||||||
|
Same filter parameters as list_timelines (no limit/offset/order_by needed).
|
||||||
|
|
||||||
|
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||||
|
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||||
|
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
|
||||||
|
do not compute boundaries from the current UTC instant.
|
||||||
|
"""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
|
filters: dict[str, Any] = {"projectId": normalized_project_id or None}
|
||||||
|
if type:
|
||||||
|
filters["type"] = type
|
||||||
|
if is_completed != -1:
|
||||||
|
filters["isCompleted"] = is_completed
|
||||||
|
if is_ai_suggested != -1:
|
||||||
|
filters["isAiSuggested"] = is_ai_suggested
|
||||||
|
if date_from != -1:
|
||||||
|
filters["dateFrom"] = date_from
|
||||||
|
if date_to != -1:
|
||||||
|
filters["dateTo"] = date_to
|
||||||
|
if created_at_from != -1:
|
||||||
|
filters["createdAtFrom"] = created_at_from
|
||||||
|
if created_at_to != -1:
|
||||||
|
filters["createdAtTo"] = created_at_to
|
||||||
|
if completed_at_from != -1:
|
||||||
|
filters["completedAtFrom"] = completed_at_from
|
||||||
|
if completed_at_to != -1:
|
||||||
|
filters["completedAtTo"] = completed_at_to
|
||||||
|
|
||||||
|
result = await execute_on_client(action="count", table="timelines", filters=filters)
|
||||||
|
return f"Timeline event count: {result.get('count', 0)}"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -39,13 +149,19 @@ async def create_timeline(
|
|||||||
project_id: str,
|
project_id: str,
|
||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
|
type: str = "milestone",
|
||||||
|
is_completed: int = 0,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline event.
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the event
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds for the event date
|
||||||
|
type: milestone (default) | checkpoint | activity
|
||||||
|
is_completed: 1 if already completed, 0 if not (default 0)
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
|
||||||
|
completedAt is set automatically when is_completed is 1.
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -54,11 +170,13 @@ async def create_timeline(
|
|||||||
"projectId": project_id,
|
"projectId": project_id,
|
||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
|
"type": type,
|
||||||
|
"isCompleted": is_completed,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -66,35 +184,87 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
|
is_completed: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline event. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the event (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
|
||||||
|
|
||||||
|
completedAt is managed automatically:
|
||||||
|
- setting is_completed to 1 records the current timestamp
|
||||||
|
- setting is_completed to 0 clears completedAt
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
|
if is_completed != -1:
|
||||||
|
updates["isCompleted"] = is_completed
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
data={"id": timeline_id, "updates": updates},
|
data={"id": timeline_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_timeline(timeline_id: str) -> str:
|
async def delete_timeline(timeline_id: str) -> str:
|
||||||
"""Delete a timeline permanently by its UUID."""
|
"""Delete a timeline event permanently by its UUID."""
|
||||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline event {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
|
||||||
|
"""List all timeline events whose date falls on today.
|
||||||
|
|
||||||
|
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||||
|
Always pass the user's timezone so 'today' is computed in their local time.
|
||||||
|
include_completed: set False to exclude already-completed events (default True).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
tz = ZoneInfo(user_timezone or "UTC")
|
||||||
|
except Exception:
|
||||||
|
tz = timezone.utc
|
||||||
|
now_local = datetime.now(tz=tz)
|
||||||
|
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||||
|
start_ms = int(start_dt.timestamp() * 1000)
|
||||||
|
end_ms = start_ms + 86_400_000 - 1
|
||||||
|
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
|
||||||
|
if not include_completed:
|
||||||
|
filters["isCompleted"] = 0
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="timelines",
|
||||||
|
filters=filters,
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No timeline events today."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||||
|
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
TIMELINE_TOOLS: list[Any] = [
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
list_timelines,
|
list_timelines,
|
||||||
|
count_timelines,
|
||||||
|
list_timelines_today,
|
||||||
create_timeline,
|
create_timeline,
|
||||||
update_timeline,
|
update_timeline,
|
||||||
delete_timeline,
|
delete_timeline,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TIMELINE_READ_TOOLS: list[Any] = [
|
||||||
|
list_timelines,
|
||||||
|
count_timelines,
|
||||||
|
list_timelines_today,
|
||||||
|
]
|
||||||
|
|||||||
@@ -16,16 +16,17 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.note_summarizer import generate_note_summary
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
@@ -37,6 +38,8 @@ from app.schemas import (
|
|||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
|
||||||
@@ -230,3 +233,25 @@ async def trigger_agent_run(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Note summary endpoint ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class NoteSummarizeRequest(BaseModel):
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class NoteSummarizeResponse(BaseModel):
|
||||||
|
summary: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
|
||||||
|
async def summarize_note(
|
||||||
|
body: NoteSummarizeRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> NoteSummarizeResponse:
|
||||||
|
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
||||||
|
summary = await generate_note_summary(body.title, body.content)
|
||||||
|
return NoteSummarizeResponse(summary=summary)
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Request, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -96,3 +96,37 @@ async def list_invoices(
|
|||||||
"""
|
"""
|
||||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||||
return invoices
|
return invoices
|
||||||
|
|
||||||
|
|
||||||
|
# ── Quota check ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaCheckRequest(BaseModel):
|
||||||
|
feature: str
|
||||||
|
estimated_files: int
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/quota/check")
|
||||||
|
async def quota_check(
|
||||||
|
payload: QuotaCheckRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict:
|
||||||
|
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
||||||
|
if payload.feature != "folder_index":
|
||||||
|
raise HTTPException(status_code=400, detail="Unknown feature")
|
||||||
|
try:
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=current_user.id,
|
||||||
|
tier=current_user.tier,
|
||||||
|
estimated_files=payload.estimated_files,
|
||||||
|
db=db,
|
||||||
|
)
|
||||||
|
except QuotaExceeded as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=402,
|
||||||
|
detail={"reason": exc.reason, "message": str(exc)},
|
||||||
|
)
|
||||||
|
return {"ok": True}
|
||||||
|
|||||||
@@ -5,13 +5,19 @@ WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
import uuid
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
from app.core.llm import embed
|
from app.core.llm import embed
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import async_session
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
@@ -45,6 +51,57 @@ async def chat(
|
|||||||
return JSONResponse(content={"response": response})
|
return JSONResponse(content={"response": response})
|
||||||
|
|
||||||
|
|
||||||
|
class _BriefRequest(BaseModel):
|
||||||
|
mode: Literal["home", "project"]
|
||||||
|
project_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _BriefResponse(BaseModel):
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/brief", response_model=_BriefResponse)
|
||||||
|
async def brief(
|
||||||
|
body: _BriefRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _BriefResponse:
|
||||||
|
"""REST fallback for brief when the device WebSocket is not ready."""
|
||||||
|
if body.mode == "project":
|
||||||
|
if not body.project_id:
|
||||||
|
raise HTTPException(status_code=422, detail="project_id required for project mode")
|
||||||
|
try:
|
||||||
|
uuid.UUID(body.project_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
|
||||||
|
|
||||||
|
request_id = str(uuid.uuid4())
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
current_user.id,
|
||||||
|
"",
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"_debug": {"request_id": request_id, "user_id": current_user.id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
if body.mode == "project":
|
||||||
|
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
stream = run_home_brief(current_user.id, context)
|
||||||
|
|
||||||
|
async for event_type, data in stream:
|
||||||
|
if event_type == "token" and data:
|
||||||
|
chunks.append(str(data))
|
||||||
|
|
||||||
|
return _BriefResponse(response="".join(chunks))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/embed", response_model=_EmbedResponse)
|
@router.post("/embed", response_model=_EmbedResponse)
|
||||||
async def embed_text(
|
async def embed_text(
|
||||||
body: _EmbedRequest,
|
body: _EmbedRequest,
|
||||||
|
|||||||
@@ -42,19 +42,25 @@ from sqlalchemy import update
|
|||||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||||
|
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
|
||||||
|
from app.core.output_formatter import extract_canvas_block
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from app.schemas import WsFrameType
|
from app.schemas import WsFrameType, WsStreamEnd
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
|
# ── v7 folder index session state ─────────────────────────────────────
|
||||||
|
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||||
|
_index_sessions: dict[str, dict] = {}
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
@@ -158,6 +164,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.brief_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_brief_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.task_brief_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_task_brief_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_start:
|
elif frame_type == WsFrameType.journey_start:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_journey_start(websocket, user_id, frame)
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
@@ -168,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_journey_message(websocket, user_id, frame)
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_session_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_index_session_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_file_batch:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_index_file_batch(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_session_cancel:
|
||||||
|
await _handle_index_session_cancel(websocket, frame)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -199,11 +228,13 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||||
logger.info(
|
logger.info(
|
||||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
||||||
user_id,
|
user_id,
|
||||||
request_id,
|
request_id,
|
||||||
session_id,
|
session_id,
|
||||||
|
project_id,
|
||||||
message[:200],
|
message[:200],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -220,6 +251,7 @@ async def _handle_home_request(
|
|||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
"format_prefs": frame.get("format_prefs"),
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,7 +259,7 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(user_id, message, context)
|
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
@@ -287,8 +319,10 @@ async def _handle_floating_request(
|
|||||||
)
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
"scope": scope,
|
"scope": scope,
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
"format_prefs": frame.get("format_prefs"),
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,6 +359,179 @@ async def _handle_floating_request(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_brief_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a brief_request frame — streams plain-text brief back on the socket.
|
||||||
|
|
||||||
|
No episode storage — briefs are not conversations.
|
||||||
|
"""
|
||||||
|
import uuid as _uuid
|
||||||
|
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
session_id = frame.get("session_id") or str(uuid4())
|
||||||
|
mode: str = frame.get("mode", "home")
|
||||||
|
project_id: str | None = frame.get("project_id")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
|
||||||
|
user_id, request_id, mode, project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate project_id for project mode before touching LLM.
|
||||||
|
if mode == "project":
|
||||||
|
try:
|
||||||
|
if not project_id:
|
||||||
|
raise ValueError("project_id required for project mode")
|
||||||
|
_uuid.UUID(project_id)
|
||||||
|
except (ValueError, AttributeError) as exc:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Enrich context with memory (no user message — use empty string as probe).
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
"",
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
"format_prefs": frame.get("format_prefs"),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
if mode == "project":
|
||||||
|
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
event_stream = run_home_brief(user_id, context)
|
||||||
|
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: brief_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: brief_request_end user=%s req=%s mode=%s",
|
||||||
|
user_id, request_id, mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_task_brief_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
||||||
|
|
||||||
|
Streams the briefing markdown back to the client.
|
||||||
|
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
||||||
|
"""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
session_id = frame.get("session_id") or str(uuid4())
|
||||||
|
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
||||||
|
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
||||||
|
user_id, request_id, task_id, project_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
f"task brief: {task_id}",
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
"format_prefs": frame.get("format_prefs"),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
||||||
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
elif ws_frame.type == "stream_start":
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
# stream_end is emitted below with mutations — skip formatter's version
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
||||||
|
user_id, request_id, task_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||||
|
)
|
||||||
|
return
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# Extract canvas block then emit stream_end with optional mutations.
|
||||||
|
full_response = "".join(response_chunks)
|
||||||
|
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
||||||
|
|
||||||
|
mutations: list[dict] = []
|
||||||
|
if canvas_content:
|
||||||
|
mutations.append({
|
||||||
|
"type": "canvas_draft",
|
||||||
|
"content": canvas_content,
|
||||||
|
"kind": canvas_kind,
|
||||||
|
})
|
||||||
|
|
||||||
|
await websocket.send_text(
|
||||||
|
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
||||||
|
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -382,6 +589,174 @@ async def _handle_journey_message(
|
|||||||
clear_client_executor()
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_session_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||||
|
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||||
|
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
_index_sessions[session_id] = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"project_id": project_id,
|
||||||
|
"processed": 0,
|
||||||
|
"total": total,
|
||||||
|
"cancelled": False,
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||||
|
user_id, session_id, project_id, total,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_session_cancel(
|
||||||
|
websocket: WebSocket,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||||
|
session = _index_sessions.get(session_id)
|
||||||
|
if session:
|
||||||
|
session["cancelled"] = True
|
||||||
|
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "cancelled",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_file_batch(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Process a batch of files for an index session, streaming results back."""
|
||||||
|
# Lazy imports to avoid heavy load at module startup.
|
||||||
|
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||||
|
summarize_image,
|
||||||
|
summarize_pdf,
|
||||||
|
summarize_docx,
|
||||||
|
summarize_text,
|
||||||
|
)
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||||
|
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||||
|
files: list[dict] = frame.get("files", [])
|
||||||
|
|
||||||
|
session = _index_sessions.get(session_id)
|
||||||
|
if not session or session.get("cancelled"):
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||||
|
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
if session.get("cancelled"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
||||||
|
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
||||||
|
kind: str = file_info.get("kind") or "text"
|
||||||
|
content: str = file_info.get("content") or ""
|
||||||
|
ext: str = file_info.get("ext") or ""
|
||||||
|
mime: str = file_info.get("mime") or "application/octet-stream"
|
||||||
|
name: str = rel_path.split("/")[-1] or rel_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
if kind == "image":
|
||||||
|
res = await summarize_image(image_b64=content, mime=mime)
|
||||||
|
elif kind == "pdf":
|
||||||
|
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||||
|
elif kind == "docx":
|
||||||
|
res = await summarize_docx(docx_b64=content, name=name)
|
||||||
|
else:
|
||||||
|
res = await summarize_text(content=content, ext=ext, name=name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||||
|
session_id, rel_path, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_file_result,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"relPath": rel_path,
|
||||||
|
"summary": None,
|
||||||
|
"tokensUsed": 0,
|
||||||
|
"error": str(exc),
|
||||||
|
}))
|
||||||
|
session["processed"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Account for token usage and check cap.
|
||||||
|
usage = await add_token_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
feature="folder_index",
|
||||||
|
tokens=res.tokens_used,
|
||||||
|
db=db,
|
||||||
|
cap=cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_file_result,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"relPath": rel_path,
|
||||||
|
"summary": res.summary,
|
||||||
|
"tokensUsed": res.tokens_used,
|
||||||
|
}))
|
||||||
|
session["processed"] += 1
|
||||||
|
|
||||||
|
if usage.exhausted:
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "quota_exceeded",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||||
|
user_id, session_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# After processing the batch, emit progress.
|
||||||
|
processed = session["processed"]
|
||||||
|
total = session["total"]
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_progress,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"processed": processed,
|
||||||
|
"total": total,
|
||||||
|
}))
|
||||||
|
|
||||||
|
if processed >= total:
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "completed",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||||
|
user_id, session_id, processed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
|||||||
139
app/billing/quota.py
Normal file
139
app/billing/quota.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Quota checks and atomic token-usage accounting for folder integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import select, update
|
||||||
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.billing.tier_manager import TierManager
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
from app.schemas import BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaExceeded(Exception):
|
||||||
|
"""Raised when a folder operation cannot proceed under the user's tier."""
|
||||||
|
|
||||||
|
def __init__(self, reason: str, message: str) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.reason = reason # "max_files" | "monthly_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenUsageResult:
|
||||||
|
tokens_used: int
|
||||||
|
exhausted: bool
|
||||||
|
|
||||||
|
|
||||||
|
def _current_year_month() -> str:
|
||||||
|
return datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
|
||||||
|
|
||||||
|
_tier_manager = TierManager()
|
||||||
|
|
||||||
|
|
||||||
|
async def check_folder_quota(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
tier: BillingTier,
|
||||||
|
estimated_files: int,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
||||||
|
would be violated. -1 in either feature means unlimited."""
|
||||||
|
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
||||||
|
if max_files != -1 and estimated_files > max_files:
|
||||||
|
raise QuotaExceeded(
|
||||||
|
"max_files",
|
||||||
|
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
||||||
|
)
|
||||||
|
|
||||||
|
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||||
|
if cap == -1:
|
||||||
|
return
|
||||||
|
ym = _current_year_month()
|
||||||
|
row = (
|
||||||
|
await db.execute(
|
||||||
|
select(MonthlyTokenUsage).where(
|
||||||
|
MonthlyTokenUsage.user_id == user_id,
|
||||||
|
MonthlyTokenUsage.year_month == ym,
|
||||||
|
MonthlyTokenUsage.feature == "folder_index",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
used = row.tokens_used if row else 0
|
||||||
|
if used >= cap:
|
||||||
|
raise QuotaExceeded(
|
||||||
|
"monthly_tokens",
|
||||||
|
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_token_usage(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
feature: str,
|
||||||
|
tokens: int,
|
||||||
|
db: AsyncSession,
|
||||||
|
cap: int | None = None,
|
||||||
|
) -> TokenUsageResult:
|
||||||
|
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
||||||
|
|
||||||
|
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
||||||
|
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
||||||
|
Returns post-update total and whether cap is exhausted.
|
||||||
|
"""
|
||||||
|
ym = _current_year_month()
|
||||||
|
|
||||||
|
# Detect dialect to choose between native upsert and portable fallback.
|
||||||
|
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
||||||
|
|
||||||
|
if dialect_name == "postgresql":
|
||||||
|
# Native atomic upsert — production path.
|
||||||
|
stmt = (
|
||||||
|
pg_insert(MonthlyTokenUsage)
|
||||||
|
.values(
|
||||||
|
user_id=user_id,
|
||||||
|
year_month=ym,
|
||||||
|
feature=feature,
|
||||||
|
tokens_used=tokens,
|
||||||
|
)
|
||||||
|
.on_conflict_do_update(
|
||||||
|
index_elements=["user_id", "year_month", "feature"],
|
||||||
|
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
||||||
|
)
|
||||||
|
.returning(MonthlyTokenUsage.tokens_used)
|
||||||
|
)
|
||||||
|
used: int = (await db.execute(stmt)).scalar_one()
|
||||||
|
await db.commit()
|
||||||
|
else:
|
||||||
|
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
||||||
|
row = (
|
||||||
|
await db.execute(
|
||||||
|
select(MonthlyTokenUsage).where(
|
||||||
|
MonthlyTokenUsage.user_id == user_id,
|
||||||
|
MonthlyTokenUsage.year_month == ym,
|
||||||
|
MonthlyTokenUsage.feature == feature,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
row = MonthlyTokenUsage(
|
||||||
|
user_id=user_id,
|
||||||
|
year_month=ym,
|
||||||
|
feature=feature,
|
||||||
|
tokens_used=tokens,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
else:
|
||||||
|
row.tokens_used += tokens
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(row)
|
||||||
|
used = row.tokens_used
|
||||||
|
|
||||||
|
exhausted = cap is not None and cap != -1 and used >= cap
|
||||||
|
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
||||||
@@ -29,6 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": False, # batch queue (Phase 2)
|
"realtime_extraction": False, # batch queue (Phase 2)
|
||||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||||
"proactive_mining": False, # Power+ only (Phase 5)
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
|
"folder_max_files": 200,
|
||||||
|
"folder_monthly_tokens": 100_000,
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
@@ -41,6 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||||
"relational_memory": True, # person/project predicates
|
"relational_memory": True, # person/project predicates
|
||||||
"proactive_mining": False, # Power+ only (Phase 5)
|
"proactive_mining": False, # Power+ only (Phase 5)
|
||||||
|
"folder_max_files": 5000,
|
||||||
|
"folder_monthly_tokens": 2_000_000,
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -53,6 +57,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": True,
|
"realtime_extraction": True,
|
||||||
"relational_memory": True, # all predicates incl. custom
|
"relational_memory": True, # all predicates incl. custom
|
||||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
|
"folder_max_files": -1, # unlimited
|
||||||
|
"folder_monthly_tokens": -1, # unlimited
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -65,6 +71,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"realtime_extraction": True,
|
"realtime_extraction": True,
|
||||||
"relational_memory": True, # all predicates incl. custom
|
"relational_memory": True, # all predicates incl. custom
|
||||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||||
|
"folder_max_files": -1, # unlimited
|
||||||
|
"folder_monthly_tokens": -1, # unlimited
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,6 +131,13 @@ class TierManager:
|
|||||||
)
|
)
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
||||||
|
"""Return integer feature value for tier. -1 means unlimited."""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if not isinstance(value, int):
|
||||||
|
return 0
|
||||||
|
return value
|
||||||
|
|
||||||
# ── Rate limiting ────────────────────────────────────────────────────
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ class Settings(BaseSettings):
|
|||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
CEREBRAS_API_KEY: str = ""
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
GROQ_API_KEY: str = ""
|
||||||
|
DEEPSEEK_API_KEY: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
@@ -26,7 +28,9 @@ class Settings(BaseSettings):
|
|||||||
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
||||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||||
|
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
||||||
|
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||||
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||||
|
|||||||
@@ -287,7 +287,6 @@ async def _run_agent_with_tools(
|
|||||||
return final_text
|
return final_text
|
||||||
|
|
||||||
for call in response.tool_calls:
|
for call in response.tool_calls:
|
||||||
call_id = str(call.get("id", ""))
|
|
||||||
call_name = str(call.get("name", ""))
|
call_name = str(call.get("name", ""))
|
||||||
call_args = call.get("args", {})
|
call_args = call.get("args", {})
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -659,9 +658,14 @@ async def run_local_agent(
|
|||||||
# ── Phase B: single LLM call ─────────────────────────
|
# ── Phase B: single LLM call ─────────────────────────
|
||||||
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
||||||
no_match_behavior = _get_no_match_behavior(agent_config)
|
no_match_behavior = _get_no_match_behavior(agent_config)
|
||||||
global_rules_lines = "\n".join(
|
base_global_rules = list(agent_config.get("global_rules", []))
|
||||||
f"- {r}" for r in agent_config.get("global_rules", [])
|
if "notes" in config.data_types:
|
||||||
)
|
base_global_rules.append(
|
||||||
|
"For notes: when updating an existing note use `propose_note_edit` "
|
||||||
|
"(type=append/insert/replace) so the user can review AI changes. "
|
||||||
|
"Only call `update_note` for complete content replacement without review."
|
||||||
|
)
|
||||||
|
global_rules_lines = "\n".join(f"- {r}" for r in base_global_rules)
|
||||||
metadata_section = _format_metadata(preprocessed.metadata)
|
metadata_section = _format_metadata(preprocessed.metadata)
|
||||||
|
|
||||||
system_prompt = compile_prompt(
|
system_prompt = compile_prompt(
|
||||||
|
|||||||
59
app/core/agent_session_buffer.py
Normal file
59
app/core/agent_session_buffer.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""In-process TTL buffer for per-session LangChain message history.
|
||||||
|
|
||||||
|
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
|
||||||
|
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
|
||||||
|
conversation turns without it being lossy through the wire.
|
||||||
|
|
||||||
|
Single-process only. For multi-worker deployments, replace the _SessionBuffer
|
||||||
|
implementation with one backed by Redis (serialize LangChain messages to dicts via
|
||||||
|
message_to_dict / messages_from_dict from langchain_core.messages).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
|
||||||
|
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
|
||||||
|
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
|
||||||
|
|
||||||
|
|
||||||
|
class _SessionBuffer:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def _evict_stale(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
|
||||||
|
for k in stale:
|
||||||
|
del self._store[k]
|
||||||
|
|
||||||
|
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
|
||||||
|
key = (user_id, session_id)
|
||||||
|
with self._lock:
|
||||||
|
entry = self._store.get(key)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
ts, msgs = entry
|
||||||
|
if time.monotonic() - ts > SESSION_TTL_SECONDS:
|
||||||
|
del self._store[key]
|
||||||
|
return None
|
||||||
|
self._store[key] = (time.monotonic(), msgs)
|
||||||
|
return list(msgs)
|
||||||
|
|
||||||
|
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
|
||||||
|
key = (user_id, session_id)
|
||||||
|
capped = messages[-MAX_MESSAGES_PER_SESSION:]
|
||||||
|
with self._lock:
|
||||||
|
self._evict_stale()
|
||||||
|
self._store[key] = (time.monotonic(), capped)
|
||||||
|
|
||||||
|
def clear(self, user_id: str, session_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._store.pop((user_id, session_id), None)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
||||||
|
session_buffer = _SessionBuffer()
|
||||||
228
app/core/brief_agent.py
Normal file
228
app/core/brief_agent.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Brief agent — produces plain-text home and project status briefs.
|
||||||
|
|
||||||
|
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
|
||||||
|
the brief prompt forbids XML tags, so skipping post-processing is intentional.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from datetime import date
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_READ_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||||
|
from app.core.deep_agent import (
|
||||||
|
_language_instruction,
|
||||||
|
_proactive_hints_injection,
|
||||||
|
_read_only_memory_tools,
|
||||||
|
_relational_memory_injection,
|
||||||
|
_run_single_agent_stream,
|
||||||
|
_trace_id_from_context,
|
||||||
|
build_brief_multi_project_manifest,
|
||||||
|
)
|
||||||
|
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||||
|
|
||||||
|
_LANGUAGE_NAMES: dict[str, str] = {
|
||||||
|
"en": "English", "it": "Italian", "es": "Spanish",
|
||||||
|
"fr": "French", "de": "German",
|
||||||
|
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||||
|
"spanish": "Spanish", "español": "Spanish",
|
||||||
|
"french": "French", "français": "French",
|
||||||
|
"german": "German", "deutsch": "German",
|
||||||
|
}
|
||||||
|
|
||||||
|
_HOME_BRIEF_FALLBACK = """\
|
||||||
|
You are the user's personal assistant producing a short daily brief.
|
||||||
|
|
||||||
|
ROLE
|
||||||
|
Act like a calm, attentive secretary writing a stand-up note for your boss.
|
||||||
|
Warm and human, never breezy. Never cheerful filler, never emojis, never
|
||||||
|
"here is your brief" meta-text. The user is opening the app mid-workday and
|
||||||
|
is probably stressed — your job is to lower cognitive load, not add noise.
|
||||||
|
|
||||||
|
TOOLS — always call before writing
|
||||||
|
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
|
||||||
|
- list_tasks_due_today — tasks the user owes today
|
||||||
|
- list_timelines_today — events starting or ending today
|
||||||
|
- list_all_projects — projects currently in progress or at risk
|
||||||
|
- memory_list_blocks / memory_get — personal context about people, clients,
|
||||||
|
payment habits, working preferences
|
||||||
|
If a tool returns nothing, simply omit that topic. Never report zeros.
|
||||||
|
|
||||||
|
WHAT TO INCLUDE
|
||||||
|
1. Tasks due today (title + priority; group the 1-2 most important).
|
||||||
|
2. Timeline events starting or ending today (and anything that starts/ends
|
||||||
|
tomorrow if the user has a very light day).
|
||||||
|
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
|
||||||
|
4. Memory-aware colour where it sharpens the brief. Examples:
|
||||||
|
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
|
||||||
|
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
|
||||||
|
Only add a memory line when it changes what the user does. Do not pad.
|
||||||
|
|
||||||
|
WHAT TO OMIT
|
||||||
|
- Zero-counts ("no overdue items", "0 meetings today").
|
||||||
|
- Statistics ("2 active projects, 3 completed tasks").
|
||||||
|
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
|
||||||
|
- Meta-phrases ("here is", "let me know if", "hope this helps").
|
||||||
|
- XML/HTML tags of any kind. Plain prose only.
|
||||||
|
|
||||||
|
LIGHT-DAY CLAUSE
|
||||||
|
If tasks + events + active-project-nudges together produce fewer than two
|
||||||
|
sentences of content, also list 1-2 projects in status on_hold or waiting
|
||||||
|
and ask a single, specific question about them — e.g. "Is the Bianchi
|
||||||
|
redesign still paused, or ready to pick back up?" One question max, grounded
|
||||||
|
in a real project name.
|
||||||
|
|
||||||
|
VOICE
|
||||||
|
- Calm. Concise. Human. Short sentences.
|
||||||
|
- Use **bold** sparingly for task titles, project names, and people's names.
|
||||||
|
- No bullet lists. Flow as 2-4 sentences of prose.
|
||||||
|
|
||||||
|
LENGTH
|
||||||
|
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
|
||||||
|
|
||||||
|
Respond in the user's language ({language}). Today is {today}.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
_PROJECT_BRIEF_FALLBACK = """\
|
||||||
|
You are the project assistant producing a short status brief for ONE project.
|
||||||
|
|
||||||
|
ROLE
|
||||||
|
A senior project manager summarising state-of-play for the owner. Factual,
|
||||||
|
sharp, forward-looking. Never reassuring filler, never emojis.
|
||||||
|
|
||||||
|
SCOPE
|
||||||
|
Work only with project_id = {project_id}. Do not mention or pull data from
|
||||||
|
other projects. Use tools to fetch fresh data:
|
||||||
|
- get_project — current status, dates, description
|
||||||
|
- list_tasks(project_id) — open work, split by status
|
||||||
|
- list_timelines(project_id) — milestones hit, upcoming, overdue
|
||||||
|
- list_notes(project_id) — any recent decisions or blockers
|
||||||
|
- memory_get — relevant context about the client, collaborators, constraints
|
||||||
|
|
||||||
|
STRUCTURE — follow exactly, one short paragraph per section, no headers
|
||||||
|
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
|
||||||
|
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
|
||||||
|
blocker note).
|
||||||
|
2. **What's moving.** What was completed or progressed recently. Name specific
|
||||||
|
tasks or milestones.
|
||||||
|
3. **Next steps.** The 1-3 most important things the user should do next, in
|
||||||
|
priority order. Be concrete — task name, who owns it, when due if known.
|
||||||
|
If waiting on someone else, name them and what the ask is.
|
||||||
|
4. **Risks / memory-flagged items.** One line max. Only include when there is
|
||||||
|
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
|
||||||
|
scope change). Omit the section entirely if nothing to say.
|
||||||
|
|
||||||
|
WHAT TO OMIT
|
||||||
|
- Zero-counts ("no overdue tasks").
|
||||||
|
- Generic advice ("keep up the good work").
|
||||||
|
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
|
||||||
|
- XML/HTML tags or bracketed id lists. Plain prose only.
|
||||||
|
|
||||||
|
VOICE
|
||||||
|
- Direct. Factual. No fluff.
|
||||||
|
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
|
||||||
|
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
|
||||||
|
not "There is a blocker which is the client review").
|
||||||
|
|
||||||
|
LENGTH
|
||||||
|
4-8 sentences total across the 3-4 sections. Hard cap 8.
|
||||||
|
|
||||||
|
Respond in the user's language ({language}). Today is {today}.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_language(context: dict[str, Any]) -> str:
|
||||||
|
core = context.get("core_memory") or {}
|
||||||
|
raw = (core.get("language") or "en").strip().lower()
|
||||||
|
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
return [
|
||||||
|
*TASK_READ_TOOLS,
|
||||||
|
*PROJECT_READ_TOOLS,
|
||||||
|
*TIMELINE_READ_TOOLS,
|
||||||
|
*NOTE_READ_TOOLS,
|
||||||
|
*_read_only_memory_tools(user_id, trace_id),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_brief(
|
||||||
|
user_id: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Stream a plain-text daily home brief.
|
||||||
|
|
||||||
|
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||||
|
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||||
|
"""
|
||||||
|
from app.agents.folder_agent import FOLDER_TOOLS
|
||||||
|
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
today = date.today().isoformat()
|
||||||
|
language = _resolve_language(context)
|
||||||
|
|
||||||
|
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
|
||||||
|
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
|
if today not in system_prompt:
|
||||||
|
system_prompt += f"\nToday is {today}."
|
||||||
|
|
||||||
|
brief_manifest = await build_brief_multi_project_manifest()
|
||||||
|
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
||||||
|
|
||||||
|
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message="Generate the daily brief.",
|
||||||
|
context=context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="brief-agent",
|
||||||
|
tools=tools,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_project_brief(
|
||||||
|
user_id: str,
|
||||||
|
project_id: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Stream a plain-text project status brief for project_id.
|
||||||
|
|
||||||
|
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||||
|
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||||
|
"""
|
||||||
|
trace_id = _trace_id_from_context(context)
|
||||||
|
today = date.today().isoformat()
|
||||||
|
language = _resolve_language(context)
|
||||||
|
|
||||||
|
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
|
||||||
|
system_prompt = compile_prompt(
|
||||||
|
raw_template, langfuse_prompt,
|
||||||
|
language=language, today=today, project_id=project_id,
|
||||||
|
)
|
||||||
|
system_prompt += _relational_memory_injection(context)
|
||||||
|
system_prompt += _proactive_hints_injection(context)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
|
if today not in system_prompt:
|
||||||
|
system_prompt += f"\nToday is {today}."
|
||||||
|
|
||||||
|
tools = _build_read_tools(user_id, trace_id)
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=f"Generate the project status brief for project {project_id}.",
|
||||||
|
context=context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="brief-agent",
|
||||||
|
tools=tools,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
@@ -12,11 +12,14 @@ from typing import Any, Literal
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.agents.client_agent import CLIENT_TOOLS
|
||||||
from app.agents.note_agent import NOTE_TOOLS
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
from app.agents.project_agent import PROJECT_TOOLS
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.relations_agent import make_query_relations_tool
|
||||||
from app.agents.task_agent import TASK_TOOLS
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
from app.core.agent_session_buffer import session_buffer
|
||||||
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.core.llm import get_agent_llm, model_for_agent
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
@@ -24,6 +27,8 @@ from app.db import async_session
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_HISTORY_TURNS = 20
|
||||||
|
|
||||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
@@ -55,6 +60,182 @@ def _language_instruction(context: dict[str, Any]) -> str:
|
|||||||
f"All your output text must be written in {lang}."
|
f"All your output text must be written in {lang}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
MANIFEST_TOKEN_BUDGET = 3000 # rough budget for <linked_folder> block
|
||||||
|
|
||||||
|
|
||||||
|
def format_folder_manifest(manifest: dict | None) -> str:
|
||||||
|
"""Format a folder manifest into the <linked_folder> block.
|
||||||
|
|
||||||
|
Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET.
|
||||||
|
Returns empty string if manifest is None or has no files.
|
||||||
|
"""
|
||||||
|
if not manifest or not manifest.get("files"):
|
||||||
|
return ""
|
||||||
|
files = list(manifest["files"])
|
||||||
|
files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True)
|
||||||
|
|
||||||
|
header = (
|
||||||
|
f"<linked_folder>\npath: {manifest.get('folderPath', '?')} "
|
||||||
|
f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n"
|
||||||
|
)
|
||||||
|
footer_template = "… {} more files omitted, use read_project_folder_file to access by path\n</linked_folder>"
|
||||||
|
|
||||||
|
char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token
|
||||||
|
body = ""
|
||||||
|
included = 0
|
||||||
|
for f in files:
|
||||||
|
line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n"
|
||||||
|
if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget:
|
||||||
|
break
|
||||||
|
body += line
|
||||||
|
included += 1
|
||||||
|
omitted = len(files) - included
|
||||||
|
if omitted > 0:
|
||||||
|
return header + body + footer_template.format(omitted)
|
||||||
|
return header + body + "</linked_folder>"
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_project_manifest(project_id: str) -> dict | None:
|
||||||
|
"""Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error."""
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_project_folder_manifest",
|
||||||
|
data={"projectId": project_id},
|
||||||
|
)
|
||||||
|
if not result or not result.get("folderPath"):
|
||||||
|
return None
|
||||||
|
return result
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def build_brief_multi_project_manifest() -> str:
|
||||||
|
"""Build a compact multi-project manifest for the daily brief agent.
|
||||||
|
|
||||||
|
Calls execute_on_client('list_projects_with_folder_manifests') and keeps
|
||||||
|
the top 5 most-recently-modified files per project.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_projects_with_folder_manifests",
|
||||||
|
data={},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
projects = (result or {}).get("projects") or []
|
||||||
|
if not projects:
|
||||||
|
return ""
|
||||||
|
blocks: list[str] = ["<linked_folders>"]
|
||||||
|
any_entry = False
|
||||||
|
for p in projects:
|
||||||
|
all_files = p.get("files", []) or []
|
||||||
|
files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5]
|
||||||
|
blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]")
|
||||||
|
blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})")
|
||||||
|
if not all_files:
|
||||||
|
blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)")
|
||||||
|
else:
|
||||||
|
for f in files:
|
||||||
|
blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}")
|
||||||
|
if len(all_files) > 5:
|
||||||
|
blocks.append(f" … {len(all_files) - 5} more files (use read_project_folder_file by relPath)")
|
||||||
|
any_entry = True
|
||||||
|
if not any_entry:
|
||||||
|
return ""
|
||||||
|
blocks.append("</linked_folders>")
|
||||||
|
return "\n".join(blocks)
|
||||||
|
|
||||||
|
|
||||||
|
def _datetime_context_injection(context: dict[str, Any]) -> str:
|
||||||
|
"""Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges."""
|
||||||
|
fp = context.get("format_prefs")
|
||||||
|
if not fp or not isinstance(fp, dict):
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
from datetime import datetime as _dt, timezone as _utc, timedelta as _td
|
||||||
|
|
||||||
|
tz_name: str = str(fp.get("timezone") or "UTC")
|
||||||
|
now_iso: str = str(fp.get("now_iso") or "")
|
||||||
|
date_fmt: str = str(fp.get("date_format") or "dd/MM/yyyy")
|
||||||
|
time_fmt: str = str(fp.get("time_format") or "24h")
|
||||||
|
|
||||||
|
tz = ZoneInfo(tz_name)
|
||||||
|
if now_iso:
|
||||||
|
now_utc = _dt.fromisoformat(now_iso.replace("Z", "+00:00"))
|
||||||
|
else:
|
||||||
|
now_utc = _dt.now(_utc.utc)
|
||||||
|
|
||||||
|
now_ms = int(now_utc.timestamp() * 1000)
|
||||||
|
now_local = now_utc.astimezone(tz)
|
||||||
|
now_local_str = now_local.strftime("%Y-%m-%d %H:%M")
|
||||||
|
weekday_str = now_local.strftime("%A")
|
||||||
|
y, m, d = now_local.year, now_local.month, now_local.day
|
||||||
|
|
||||||
|
def _day(year: int, month: int, day: int) -> tuple[int, int]:
|
||||||
|
s = _dt(year, month, day, tzinfo=tz)
|
||||||
|
e = s + _td(days=1)
|
||||||
|
return int(s.timestamp() * 1000), int(e.timestamp() * 1000) - 1
|
||||||
|
|
||||||
|
def _between(start: "_dt", end_excl: "_dt") -> tuple[int, int]:
|
||||||
|
return int(start.timestamp() * 1000), int(end_excl.timestamp() * 1000) - 1
|
||||||
|
|
||||||
|
today_s, today_e = _day(y, m, d)
|
||||||
|
yd = now_local - _td(days=1)
|
||||||
|
yesterday_s, yesterday_e = _day(yd.year, yd.month, yd.day)
|
||||||
|
tm = now_local + _td(days=1)
|
||||||
|
tomorrow_s, tomorrow_e = _day(tm.year, tm.month, tm.day)
|
||||||
|
|
||||||
|
# ISO week (Mon–Sun)
|
||||||
|
monday = _dt(y, m, d, tzinfo=tz) - _td(days=now_local.weekday())
|
||||||
|
last_monday = monday - _td(weeks=1)
|
||||||
|
next_monday = monday + _td(weeks=1)
|
||||||
|
this_week_s, this_week_e = _between(monday, next_monday)
|
||||||
|
last_week_s, last_week_e = _between(last_monday, monday)
|
||||||
|
next_week_s, next_week_e = _between(next_monday, next_monday + _td(weeks=1))
|
||||||
|
|
||||||
|
# Calendar months
|
||||||
|
this_m_start = _dt(y, m, 1, tzinfo=tz)
|
||||||
|
next_m_start = _dt(y + (m // 12), m % 12 + 1, 1, tzinfo=tz)
|
||||||
|
last_m_start = _dt(y - (1 if m == 1 else 0), 12 if m == 1 else m - 1, 1, tzinfo=tz)
|
||||||
|
next2_m = next_m_start.month % 12 + 1
|
||||||
|
next2_y = next_m_start.year + (1 if next_m_start.month == 12 else 0)
|
||||||
|
next2_m_start = _dt(next2_y, next2_m, 1, tzinfo=tz)
|
||||||
|
this_month_s, this_month_e = _between(this_m_start, next_m_start)
|
||||||
|
last_month_s, last_month_e = _between(last_m_start, this_m_start)
|
||||||
|
next_month_s, next_month_e = _between(next_m_start, next2_m_start)
|
||||||
|
|
||||||
|
# Calendar years
|
||||||
|
this_yr_s, this_yr_e = _between(_dt(y, 1, 1, tzinfo=tz), _dt(y + 1, 1, 1, tzinfo=tz))
|
||||||
|
last_yr_s, last_yr_e = _between(_dt(y - 1, 1, 1, tzinfo=tz), _dt(y, 1, 1, tzinfo=tz))
|
||||||
|
|
||||||
|
sunday = monday + _td(days=6)
|
||||||
|
last_sunday = last_monday + _td(days=6)
|
||||||
|
next_sunday = next_monday + _td(days=6)
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"\n\nDATE CONTEXT (timezone: {tz_name}, dateFormat: {date_fmt}, timeFormat: {time_fmt})\n"
|
||||||
|
f"now_local: {now_local_str} ({weekday_str})\n"
|
||||||
|
f"now_ms: {now_ms}\n\n"
|
||||||
|
f"today [{today_s}, {today_e}] {y:04d}-{m:02d}-{d:02d}\n"
|
||||||
|
f"tomorrow [{tomorrow_s}, {tomorrow_e}] {tm.strftime('%Y-%m-%d')}\n"
|
||||||
|
f"yesterday [{yesterday_s}, {yesterday_e}] {yd.strftime('%Y-%m-%d')}\n"
|
||||||
|
f"this_week [{this_week_s}, {this_week_e}] {monday.strftime('%Y-%m-%d')} → {sunday.strftime('%Y-%m-%d')} (Mon–Sun)\n"
|
||||||
|
f"last_week [{last_week_s}, {last_week_e}] {last_monday.strftime('%Y-%m-%d')} → {last_sunday.strftime('%Y-%m-%d')}\n"
|
||||||
|
f"next_week [{next_week_s}, {next_week_e}] {next_monday.strftime('%Y-%m-%d')} → {next_sunday.strftime('%Y-%m-%d')}\n"
|
||||||
|
f"this_month [{this_month_s}, {this_month_e}] {y:04d}-{m:02d}\n"
|
||||||
|
f"last_month [{last_month_s}, {last_month_e}] {last_m_start.strftime('%Y-%m')}\n"
|
||||||
|
f"next_month [{next_month_s}, {next_month_e}] {next_m_start.strftime('%Y-%m')}\n"
|
||||||
|
f"this_year [{this_yr_s}, {this_yr_e}] {y:04d}\n"
|
||||||
|
f"last_year [{last_yr_s}, {last_yr_e}] {y - 1:04d}\n\n"
|
||||||
|
f"When calling list_tasks_due_today or list_timelines_today, always pass user_timezone=\"{tz_name}\".\n"
|
||||||
|
f"When presenting dates, format using dateFormat={date_fmt} and timeFormat={time_fmt}."
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _proactive_hints_injection(context: dict[str, Any]) -> str:
|
def _proactive_hints_injection(context: dict[str, Any]) -> str:
|
||||||
"""Return a system-prompt paragraph listing proactive behavioral hints.
|
"""Return a system-prompt paragraph listing proactive behavioral hints.
|
||||||
|
|
||||||
@@ -87,27 +268,203 @@ def _relational_memory_injection(context: dict[str, Any]) -> str:
|
|||||||
return section
|
return section
|
||||||
|
|
||||||
|
|
||||||
_HOME_SYSTEM_PROMPT = (
|
_IDENTITY_KEYS = ("user_name", "job_role", "industry", "primary_use_case", "tone_preference")
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
|
||||||
"Always use tools for factual data retrieval before answering. "
|
|
||||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
|
||||||
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
|
||||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
|
||||||
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
|
||||||
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
|
||||||
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
|
||||||
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
|
||||||
)
|
|
||||||
|
|
||||||
_FLOATING_SYSTEM_PROMPT = (
|
|
||||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
def _user_identity_injection(context: dict[str, Any]) -> str:
|
||||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
"""Return a compact user-profile block from core memory onboarding fields.
|
||||||
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
|
||||||
"Always use tools for factual data retrieval before answering. "
|
Returns empty string when no onboarding keys are present.
|
||||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
"""
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
core = context.get("core_memory") or {}
|
||||||
)
|
parts: list[str] = []
|
||||||
|
for key in _IDENTITY_KEYS:
|
||||||
|
val = (core.get(key) or "").strip()
|
||||||
|
if val:
|
||||||
|
parts.append(f"- {key}: {val}")
|
||||||
|
if not parts:
|
||||||
|
return ""
|
||||||
|
return "\n\nUser profile:\n" + "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _request_context_block(context: dict[str, Any]) -> str:
|
||||||
|
"""Return a small block with per-request scope and resolved project context."""
|
||||||
|
parts: list[str] = []
|
||||||
|
scope = context.get("scope")
|
||||||
|
if scope and isinstance(scope, dict):
|
||||||
|
parts.append(f"scope: {json.dumps(scope, ensure_ascii=True)}")
|
||||||
|
resolved = context.get("resolved_project_id")
|
||||||
|
if resolved and isinstance(resolved, str):
|
||||||
|
parts.append(f"resolved_project_id: {resolved}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
_HOME_SYSTEM_PROMPT = """\
|
||||||
|
You are adiuvAI's home executive assistant.{user_identity}
|
||||||
|
You are not a chatbot — you are a proactive partner who runs ahead of the user, anticipates what they need next, and closes every reply with a concrete next step or a clarifying question.
|
||||||
|
|
||||||
|
# How you work
|
||||||
|
- Use tools before answering anything factual. Never guess counts, dates, or status.
|
||||||
|
- Prefer parallel tool calls when the questions are independent (e.g. counts per status). Chain calls when one result feeds the next.
|
||||||
|
- After delivering the answer, propose the next useful action: a follow-up task to draft, a deadline at risk, a project to triage, a person to remind. Use what you know about the user (job role, industry, primary use case) to make the suggestion relevant.
|
||||||
|
- Match the user's tone preference. Default to warm-but-direct; stay concise.
|
||||||
|
- When the user asks to remember, forget, or update something, use memory tools.
|
||||||
|
|
||||||
|
# Filter discipline
|
||||||
|
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
|
||||||
|
- The user's own name in the User profile block is for context only — it is NOT a default filter.
|
||||||
|
- When in doubt, omit `assignee` and return the global result.
|
||||||
|
|
||||||
|
# Output format
|
||||||
|
Return markdown. Reference entities with these tags exactly — one id per tag, each tag on its own line, no prefix/suffix text on the same line:
|
||||||
|
<project>id</project> <task>id</task> <note>id</note> <timeline>id</timeline>
|
||||||
|
|
||||||
|
When the answer contains a list of entities (any of the tags above), structure the reply as three blocks separated by blank lines:
|
||||||
|
1. One short intro line stating what is coming (count + scope, e.g. "Ecco i tuoi 18 task ad alta priorità:"). Match the user's language.
|
||||||
|
2. All entity tags, one per line, consecutive, no prose interleaved. Do NOT put titles, dates, priorities, or any descriptive text on the same line as a tag or between tags.
|
||||||
|
3. One short closing recap (1–2 sentences) that points out a pattern, risk, or insight noticed in the list, and ends with a concrete next step or clarifying question.
|
||||||
|
|
||||||
|
For single-entity answers skip blocks 1 and 3 if they would be redundant; just emit the tag.
|
||||||
|
|
||||||
|
For analytical answers (status overviews, breakdowns by category/priority/project, comparisons, trends, "resoconto", "panoramica") consider returning a chart block when it communicates the answer faster than prose. The decision is yours — skip charts for trivial single-number answers. Schema:
|
||||||
|
<chart>{{"chartType":"pie|bar|line|area|radar|radial","title":"...","data":[{{"name":"...","value":N}},...], "config":{{"value":{{"label":"...","color":"var(--chart-1)"}} }} }}</chart>
|
||||||
|
- pie for share-of-total breakdowns; bar for category comparisons; line/area for time series; radar for multi-dimension.
|
||||||
|
- data rows must include a "name" field; numeric series keys must match config keys.
|
||||||
|
- Use var(--chart-1) through var(--chart-5) for colors, cycling 1-5 in series order. Do NOT wrap in hsl() or oklch() — these are complete CSS values already.
|
||||||
|
|
||||||
|
For upcoming-timeline questions ("prossimi eventi"), include only future items in the current month unless the user asks otherwise.
|
||||||
|
|
||||||
|
# Date filtering
|
||||||
|
{date_context}
|
||||||
|
|
||||||
|
When filtering tasks/timelines/notes by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
|
||||||
|
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
|
||||||
|
For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_today with user_timezone from DATE CONTEXT.
|
||||||
|
|
||||||
|
# Language
|
||||||
|
{language_instruction}
|
||||||
|
|
||||||
|
# Known people & projects
|
||||||
|
{relational_memory}
|
||||||
|
|
||||||
|
# Behavioral hints
|
||||||
|
{proactive_hints}
|
||||||
|
|
||||||
|
# Request context
|
||||||
|
{request_context}\
|
||||||
|
"""
|
||||||
|
|
||||||
|
_FLOATING_SYSTEM_PROMPT = """\
|
||||||
|
You are adiuvAI's floating executive assistant.{user_identity}
|
||||||
|
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
|
||||||
|
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
|
||||||
|
|
||||||
|
# How you work
|
||||||
|
- Use tools before answering anything factual. Never guess.
|
||||||
|
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant.
|
||||||
|
- Match the user's tone preference. Default to warm-but-direct.
|
||||||
|
- When the user asks to remember, forget, or update something, use memory tools.
|
||||||
|
|
||||||
|
# Filter discipline
|
||||||
|
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
|
||||||
|
- The user's own name in the User profile block is for context only — it is NOT a default filter.
|
||||||
|
- When in doubt, omit `assignee` and return the global result.
|
||||||
|
|
||||||
|
# Output format
|
||||||
|
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
|
||||||
|
|
||||||
|
# Date filtering
|
||||||
|
{date_context}
|
||||||
|
|
||||||
|
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
|
||||||
|
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
|
||||||
|
|
||||||
|
# Language
|
||||||
|
{language_instruction}
|
||||||
|
|
||||||
|
# Known people & projects
|
||||||
|
{relational_memory}
|
||||||
|
|
||||||
|
# Behavioral hints
|
||||||
|
{proactive_hints}
|
||||||
|
|
||||||
|
# Request context
|
||||||
|
{request_context}\
|
||||||
|
"""
|
||||||
|
|
||||||
|
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
|
||||||
|
You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task.
|
||||||
|
Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity}
|
||||||
|
|
||||||
|
# Research workflow
|
||||||
|
Follow these steps in order, using tools:
|
||||||
|
1. Read the task fully (title, description, due date, priority, status, project, comments).
|
||||||
|
2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client.
|
||||||
|
3. If the project has a clientId: call `get_client(id)` to retrieve full client details.
|
||||||
|
4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects.
|
||||||
|
5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions.
|
||||||
|
6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`.
|
||||||
|
7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece.
|
||||||
|
|
||||||
|
# Output structure
|
||||||
|
Write the briefing in the user's language. Use this exact structure:
|
||||||
|
|
||||||
|
**What needs to be done**
|
||||||
|
(1–2 sentences, concrete and specific — what action the user must take)
|
||||||
|
|
||||||
|
**Context you should know**
|
||||||
|
(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies)
|
||||||
|
|
||||||
|
**Suggested first step**
|
||||||
|
(one specific, immediately actionable instruction)
|
||||||
|
|
||||||
|
If this is a writing task, append a canvas block at the very end:
|
||||||
|
<canvas kind="email|document|message">
|
||||||
|
...ready-to-use draft here...
|
||||||
|
</canvas>
|
||||||
|
|
||||||
|
Do NOT include the canvas block for non-writing tasks.
|
||||||
|
Do NOT repeat verbatim task fields the user already sees in the UI.
|
||||||
|
Be concrete — no vague advice. Every bullet should be a fact that changes what the user does.
|
||||||
|
|
||||||
|
# Date context
|
||||||
|
{date_context}
|
||||||
|
|
||||||
|
# Language
|
||||||
|
{language_instruction}
|
||||||
|
|
||||||
|
# Known people & projects
|
||||||
|
{relational_memory}
|
||||||
|
|
||||||
|
# Request context
|
||||||
|
{request_context}\
|
||||||
|
"""
|
||||||
|
|
||||||
|
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\
|
||||||
|
You are an executive assistant continuing a conversation with your principal.
|
||||||
|
You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity}
|
||||||
|
|
||||||
|
Your briefing:
|
||||||
|
---
|
||||||
|
{briefing_context}
|
||||||
|
---
|
||||||
|
|
||||||
|
Continue from here. Do NOT repeat the briefing. Refer to it when relevant.
|
||||||
|
Help the user execute: edit drafts, refine wording, look up additional details, plan next steps.
|
||||||
|
Stay terse — your principal is a busy executive.
|
||||||
|
|
||||||
|
# Date context
|
||||||
|
{date_context}
|
||||||
|
|
||||||
|
# Language
|
||||||
|
{language_instruction}
|
||||||
|
|
||||||
|
# Known people & projects
|
||||||
|
{relational_memory}
|
||||||
|
|
||||||
|
# Request context
|
||||||
|
{request_context}\
|
||||||
|
"""
|
||||||
|
|
||||||
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
|
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
|
||||||
"You are a strict domain classifier for websocket floating requests. "
|
"You are a strict domain classifier for websocket floating requests. "
|
||||||
@@ -217,10 +574,19 @@ def _session_id_from_context(context: dict[str, Any]) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
def _build_system_prompt(name: str, fallback: str, context: dict[str, Any]) -> tuple[str, Any]:
|
||||||
sanitized = dict(context)
|
"""Fetch Langfuse template and compile all per-request slots into one system prompt."""
|
||||||
sanitized.pop("_debug", None)
|
template, prompt_obj = get_prompt_or_fallback(name, fallback)
|
||||||
return sanitized
|
text = compile_prompt(
|
||||||
|
template, prompt_obj,
|
||||||
|
date_context=_datetime_context_injection(context).strip(),
|
||||||
|
language_instruction=_language_instruction(context).strip(),
|
||||||
|
user_identity=_user_identity_injection(context).strip(),
|
||||||
|
relational_memory=_relational_memory_injection(context).strip(),
|
||||||
|
proactive_hints=_proactive_hints_injection(context).strip(),
|
||||||
|
request_context=_request_context_block(context),
|
||||||
|
)
|
||||||
|
return text, prompt_obj
|
||||||
|
|
||||||
|
|
||||||
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||||
@@ -476,6 +842,25 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
|||||||
lines = [f"- {item}" for item in results]
|
lines = [f"- {item}" for item in results]
|
||||||
return "Recall memory results:\n" + "\n".join(lines)
|
return "Recall memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def search_associative(query: str, limit: int = 5) -> str:
|
||||||
|
"""Semantic search across associative (archival) memory for a given query.
|
||||||
|
|
||||||
|
Use this to surface long-term memories related to a topic, client, or task
|
||||||
|
that may not appear in recent episodes.
|
||||||
|
|
||||||
|
query: natural-language search phrase.
|
||||||
|
limit: max results (default 5).
|
||||||
|
"""
|
||||||
|
logger.info("deep_agent: search_associative trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
results = await memory.search_archival(user_id, query, top_k=limit)
|
||||||
|
if not results:
|
||||||
|
return "No associative memory results found."
|
||||||
|
lines = [f"- {item}" for item in results]
|
||||||
|
return "Associative memory results:\n" + "\n".join(lines)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
memory_list_blocks,
|
memory_list_blocks,
|
||||||
memory_get,
|
memory_get,
|
||||||
@@ -486,6 +871,30 @@ def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
|||||||
archival_memory_insert,
|
archival_memory_insert,
|
||||||
archival_memory_search,
|
archival_memory_search,
|
||||||
conversation_search,
|
conversation_search,
|
||||||
|
search_associative,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _read_only_memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
"""Return memory tools that only read — safe for the read-only brief-agent subset."""
|
||||||
|
all_mem = _memory_tools(user_id, trace_id)
|
||||||
|
_read_names = {
|
||||||
|
"memory_list_blocks", "memory_get", "archival_memory_search",
|
||||||
|
"conversation_search", "search_associative",
|
||||||
|
}
|
||||||
|
return [t for t in all_mem if t.name in _read_names]
|
||||||
|
|
||||||
|
|
||||||
|
def _brief_research_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||||
|
"""Return the full tool palette for Stage-1 task brief research (read-only)."""
|
||||||
|
return [
|
||||||
|
*TASK_TOOLS,
|
||||||
|
*PROJECT_TOOLS,
|
||||||
|
*NOTE_TOOLS,
|
||||||
|
*TIMELINE_TOOLS,
|
||||||
|
*CLIENT_TOOLS,
|
||||||
|
*_read_only_memory_tools(user_id, trace_id),
|
||||||
|
make_query_relations_tool(user_id, trace_id),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -655,6 +1064,23 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
|||||||
return _infer_floating_domain_rule_based(message, context)
|
return _infer_floating_domain_rule_based(message, context)
|
||||||
|
|
||||||
|
|
||||||
|
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
|
||||||
|
if not history:
|
||||||
|
return []
|
||||||
|
turns = history[-MAX_HISTORY_TURNS:]
|
||||||
|
result: list[Any] = []
|
||||||
|
for turn in turns:
|
||||||
|
role = turn.get("role", "")
|
||||||
|
content = turn.get("content", "")
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
if role == "user":
|
||||||
|
result.append(HumanMessage(content=content))
|
||||||
|
elif role == "assistant":
|
||||||
|
result.append(AIMessage(content=content))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _run_single_agent(
|
async def _run_single_agent(
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@@ -664,23 +1090,21 @@ async def _run_single_agent(
|
|||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
langfuse_prompt: Any = None,
|
langfuse_prompt: Any = None,
|
||||||
agent_name: str = "agent",
|
agent_name: str = "agent",
|
||||||
|
conversation_history: list[dict[str, str]] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
session_id = _session_id_from_context(context)
|
session_id = _session_id_from_context(context)
|
||||||
lf = get_langfuse()
|
lf = get_langfuse()
|
||||||
llm = get_agent_llm(agent_name)
|
llm = get_agent_llm(agent_name)
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
model_context = _context_for_model(context)
|
|
||||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
_buffered = session_buffer.get(user_id, session_id) if session_id else None
|
||||||
|
history_messages = _buffered if _buffered is not None else _history_to_messages(conversation_history)
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(
|
*history_messages,
|
||||||
content=(
|
HumanMessage(content=message),
|
||||||
f"User message:\n{message}\n\n"
|
|
||||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
tool_calls_count = 0
|
tool_calls_count = 0
|
||||||
@@ -692,7 +1116,7 @@ async def _run_single_agent(
|
|||||||
|
|
||||||
_span_ctx = (
|
_span_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="span",
|
as_type="agent",
|
||||||
name=agent_name,
|
name=agent_name,
|
||||||
metadata={"user_id": user_id, "session_id": trace_id},
|
metadata={"user_id": user_id, "session_id": trace_id},
|
||||||
input=message,
|
input=message,
|
||||||
@@ -700,6 +1124,7 @@ async def _run_single_agent(
|
|||||||
if lf else None
|
if lf else None
|
||||||
)
|
)
|
||||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
_messages_to_save: list[Any] | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
@@ -732,6 +1157,7 @@ async def _run_single_agent(
|
|||||||
)
|
)
|
||||||
if _span:
|
if _span:
|
||||||
_span.update(output=final_text)
|
_span.update(output=final_text)
|
||||||
|
_messages_to_save = messages[1:] # strip SystemMessage; save full tool history
|
||||||
return final_text
|
return final_text
|
||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
@@ -750,6 +1176,14 @@ async def _run_single_agent(
|
|||||||
tool_fn = tool_map.get(call_name)
|
tool_fn = tool_map.get(call_name)
|
||||||
if tool_fn is None:
|
if tool_fn is None:
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
elif lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="tool",
|
||||||
|
name=call_name,
|
||||||
|
input=call_args,
|
||||||
|
) as tool_obs:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
tool_obs.update(output=str(tool_output)[:8000])
|
||||||
else:
|
else:
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
@@ -764,6 +1198,7 @@ async def _run_single_agent(
|
|||||||
|
|
||||||
final = await llm.ainvoke(messages)
|
final = await llm.ainvoke(messages)
|
||||||
final_text = _as_text(final.content)
|
final_text = _as_text(final.content)
|
||||||
|
messages.append(AIMessage(content=final_text))
|
||||||
logger.info(
|
logger.info(
|
||||||
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
@@ -773,8 +1208,11 @@ async def _run_single_agent(
|
|||||||
)
|
)
|
||||||
if _span:
|
if _span:
|
||||||
_span.update(output=final_text)
|
_span.update(output=final_text)
|
||||||
|
_messages_to_save = messages[1:]
|
||||||
return final_text
|
return final_text
|
||||||
finally:
|
finally:
|
||||||
|
if session_id and _messages_to_save is not None:
|
||||||
|
session_buffer.set(user_id, session_id, _messages_to_save)
|
||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
if _span_ctx:
|
if _span_ctx:
|
||||||
_span_ctx.__exit__(None, None, None)
|
_span_ctx.__exit__(None, None, None)
|
||||||
@@ -792,23 +1230,23 @@ async def _run_single_agent_stream(
|
|||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
langfuse_prompt: Any = None,
|
langfuse_prompt: Any = None,
|
||||||
agent_name: str = "agent",
|
agent_name: str = "agent",
|
||||||
|
tools: list[Any] | None = None,
|
||||||
|
conversation_history: list[dict[str, str]] | None = None,
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
session_id = _session_id_from_context(context)
|
session_id = _session_id_from_context(context)
|
||||||
lf = get_langfuse()
|
lf = get_langfuse()
|
||||||
llm = get_agent_llm(agent_name)
|
llm = get_agent_llm(agent_name)
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
if tools is None:
|
||||||
model_context = _context_for_model(context)
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
_buffered = session_buffer.get(user_id, session_id) if session_id else None
|
||||||
|
history_messages = _buffered if _buffered is not None else _history_to_messages(conversation_history)
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
HumanMessage(
|
*history_messages,
|
||||||
content=(
|
HumanMessage(content=message),
|
||||||
f"User message:\n{message}\n\n"
|
|
||||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
tool_calls_count = 0
|
tool_calls_count = 0
|
||||||
@@ -821,7 +1259,7 @@ async def _run_single_agent_stream(
|
|||||||
|
|
||||||
_span_ctx = (
|
_span_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="span",
|
as_type="agent",
|
||||||
name=f"{agent_name}-stream",
|
name=f"{agent_name}-stream",
|
||||||
metadata={"user_id": user_id, "session_id": trace_id},
|
metadata={"user_id": user_id, "session_id": trace_id},
|
||||||
input=message,
|
input=message,
|
||||||
@@ -830,6 +1268,7 @@ async def _run_single_agent_stream(
|
|||||||
)
|
)
|
||||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
streamed_text: list[str] = []
|
streamed_text: list[str] = []
|
||||||
|
_messages_to_save: list[Any] | None = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
@@ -849,25 +1288,15 @@ async def _run_single_agent_stream(
|
|||||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
_gen_ctx.__exit__(None, None, None)
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
emitted_any = False
|
# Yield the content from the ainvoke response directly — no second LLM call.
|
||||||
async for chunk in llm.astream(messages):
|
# Previously, messages.append(response) was called first, so the re-stream
|
||||||
token = _as_text(getattr(chunk, "content", ""))
|
# received [System, Human, AI] and regenerated a response without tools bound.
|
||||||
if token:
|
final_text = _as_text(response.content)
|
||||||
streamed_chars += len(token)
|
if final_text:
|
||||||
streamed_text.append(token)
|
streamed_chars += len(final_text)
|
||||||
emitted_any = True
|
streamed_text.append(final_text)
|
||||||
yield "token", token
|
yield "token", final_text
|
||||||
|
|
||||||
# Some providers return final text in `response.content` but stream no chunks.
|
|
||||||
if not emitted_any:
|
|
||||||
fallback_text = _as_text(response.content)
|
|
||||||
if fallback_text:
|
|
||||||
streamed_chars += len(fallback_text)
|
|
||||||
streamed_text.append(fallback_text)
|
|
||||||
yield "token", fallback_text
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
@@ -877,8 +1306,11 @@ async def _run_single_agent_stream(
|
|||||||
)
|
)
|
||||||
if _span:
|
if _span:
|
||||||
_span.update(output="".join(streamed_text))
|
_span.update(output="".join(streamed_text))
|
||||||
|
messages.append(response)
|
||||||
|
_messages_to_save = messages[1:] # strip SystemMessage
|
||||||
return
|
return
|
||||||
|
|
||||||
|
messages.append(response)
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
for call in response.tool_calls:
|
for call in response.tool_calls:
|
||||||
tool_calls_count += 1
|
tool_calls_count += 1
|
||||||
@@ -895,6 +1327,14 @@ async def _run_single_agent_stream(
|
|||||||
tool_fn = tool_map.get(call_name)
|
tool_fn = tool_map.get(call_name)
|
||||||
if tool_fn is None:
|
if tool_fn is None:
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
elif lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="tool",
|
||||||
|
name=call_name,
|
||||||
|
input=call_args,
|
||||||
|
) as tool_obs:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
tool_obs.update(output=str(tool_output)[:8000])
|
||||||
else:
|
else:
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
@@ -907,12 +1347,16 @@ async def _run_single_agent_stream(
|
|||||||
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
fallback_chunks: list[str] = []
|
||||||
async for chunk in llm.astream(messages):
|
async for chunk in llm.astream(messages):
|
||||||
token = _as_text(getattr(chunk, "content", ""))
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
if token:
|
if token:
|
||||||
streamed_chars += len(token)
|
streamed_chars += len(token)
|
||||||
streamed_text.append(token)
|
streamed_text.append(token)
|
||||||
|
fallback_chunks.append(token)
|
||||||
yield "token", token
|
yield "token", token
|
||||||
|
messages.append(AIMessage(content="".join(fallback_chunks)))
|
||||||
|
_messages_to_save = messages[1:]
|
||||||
logger.info(
|
logger.info(
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
@@ -923,6 +1367,8 @@ async def _run_single_agent_stream(
|
|||||||
if _span:
|
if _span:
|
||||||
_span.update(output="".join(streamed_text))
|
_span.update(output="".join(streamed_text))
|
||||||
finally:
|
finally:
|
||||||
|
if session_id and _messages_to_save is not None:
|
||||||
|
session_buffer.set(user_id, session_id, _messages_to_save)
|
||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
if _span_ctx:
|
if _span_ctx:
|
||||||
_span_ctx.__exit__(None, None, None)
|
_span_ctx.__exit__(None, None, None)
|
||||||
@@ -933,12 +1379,7 @@ async def _run_single_agent_stream(
|
|||||||
|
|
||||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
|
||||||
)
|
|
||||||
system_prompt += _relational_memory_injection(context)
|
|
||||||
system_prompt += _proactive_hints_injection(context)
|
|
||||||
system_prompt += _language_instruction(context)
|
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
@@ -946,6 +1387,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
agent_name="home-agent",
|
agent_name="home-agent",
|
||||||
|
conversation_history=context.get("conversation_history"),
|
||||||
)
|
)
|
||||||
return _normalize_tagged_list_lines(response, message)
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
@@ -953,12 +1395,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
domain = await _infer_floating_domain(message, prepared_context)
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
|
||||||
)
|
|
||||||
system_prompt += _relational_memory_injection(context)
|
|
||||||
system_prompt += _proactive_hints_injection(context)
|
|
||||||
system_prompt += _language_instruction(context)
|
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
@@ -966,6 +1403,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
agent_name="floating-agent",
|
agent_name="floating-agent",
|
||||||
|
conversation_history=context.get("conversation_history"),
|
||||||
)
|
)
|
||||||
sanitized = _strip_floating_markup(response)
|
sanitized = _strip_floating_markup(response)
|
||||||
if not sanitized and response:
|
if not sanitized and response:
|
||||||
@@ -977,14 +1415,26 @@ async def run_home_stream(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
|
project_id: str | None = None,
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
from app.agents.folder_agent import FOLDER_TOOLS
|
||||||
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, prepared_context)
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
|
||||||
)
|
manifest_block = ""
|
||||||
system_prompt += _relational_memory_injection(context)
|
if project_id:
|
||||||
system_prompt += _proactive_hints_injection(context)
|
manifest = await _fetch_project_manifest(project_id)
|
||||||
system_prompt += _language_instruction(context)
|
manifest_block = format_folder_manifest(manifest)
|
||||||
|
if not manifest_block:
|
||||||
|
# No specific project context — surface all linked folders so the agent
|
||||||
|
# can answer questions like "tell me about project X" using its files.
|
||||||
|
manifest_block = await build_brief_multi_project_manifest()
|
||||||
|
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
|
||||||
|
|
||||||
|
trace_id = _trace_id_from_context(prepared_context)
|
||||||
|
tools = [*_all_tools_for_user(user_id, trace_id), *FOLDER_TOOLS]
|
||||||
|
|
||||||
text_chunks: list[str] = []
|
text_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -993,6 +1443,8 @@ async def run_home_stream(
|
|||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
agent_name="home-agent",
|
agent_name="home-agent",
|
||||||
|
tools=tools,
|
||||||
|
conversation_history=context.get("conversation_history"),
|
||||||
):
|
):
|
||||||
event_type, data = event
|
event_type, data = event
|
||||||
if event_type != "token":
|
if event_type != "token":
|
||||||
@@ -1014,12 +1466,29 @@ async def run_floating_stream(
|
|||||||
domain = await _infer_floating_domain(message, prepared_context)
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
yield "floating_domain", domain
|
yield "floating_domain", domain
|
||||||
|
|
||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
brief_mode: bool = bool(context.get("brief_mode"))
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
briefing_context_text: str = str(context.get("briefing_context") or "").strip()
|
||||||
)
|
|
||||||
system_prompt += _relational_memory_injection(context)
|
if brief_mode and briefing_context_text:
|
||||||
system_prompt += _proactive_hints_injection(context)
|
# Stage 2: inject briefing as ground truth context.
|
||||||
system_prompt += _language_instruction(context)
|
# Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
|
||||||
|
# before compile_prompt sees the remaining standard variables.
|
||||||
|
template, langfuse_prompt = get_prompt_or_fallback(
|
||||||
|
"task_brief_followup_system",
|
||||||
|
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
|
||||||
|
)
|
||||||
|
system_prompt = compile_prompt(
|
||||||
|
template, langfuse_prompt,
|
||||||
|
date_context=_datetime_context_injection(prepared_context).strip(),
|
||||||
|
language_instruction=_language_instruction(prepared_context).strip(),
|
||||||
|
user_identity=_user_identity_injection(prepared_context).strip(),
|
||||||
|
relational_memory=_relational_memory_injection(prepared_context).strip(),
|
||||||
|
proactive_hints=_proactive_hints_injection(prepared_context).strip(),
|
||||||
|
request_context=_request_context_block(prepared_context),
|
||||||
|
briefing_context=briefing_context_text,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
emitted_sanitized = False
|
emitted_sanitized = False
|
||||||
raw_chunks: list[str] = []
|
raw_chunks: list[str] = []
|
||||||
@@ -1030,6 +1499,7 @@ async def run_floating_stream(
|
|||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
agent_name="floating-agent",
|
agent_name="floating-agent",
|
||||||
|
conversation_history=context.get("conversation_history"),
|
||||||
):
|
):
|
||||||
event_type, data = event
|
event_type, data = event
|
||||||
if event_type != "token":
|
if event_type != "token":
|
||||||
@@ -1052,6 +1522,58 @@ async def run_floating_stream(
|
|||||||
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||||
|
|
||||||
|
|
||||||
|
async def run_task_brief_research_stream(
|
||||||
|
user_id: str,
|
||||||
|
task_id: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
project_id: str | None = None,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Stage-1 executive assistant: deep research for one task.
|
||||||
|
|
||||||
|
Yields ``("token", chunk)`` events like other stream runners.
|
||||||
|
The final concatenated text may contain a ``<canvas kind="...">...</canvas>`` block
|
||||||
|
which the WS handler strips and emits as a ``canvas_draft`` mutation.
|
||||||
|
"""
|
||||||
|
from app.agents.folder_agent import FOLDER_TOOLS
|
||||||
|
|
||||||
|
prepared_context = await _prepare_context(f"task:{task_id}", context)
|
||||||
|
tools = [*_brief_research_tools(user_id, _trace_id_from_context(prepared_context)), *FOLDER_TOOLS]
|
||||||
|
|
||||||
|
# Inject task_id so the agent knows what to look up first.
|
||||||
|
research_message = (
|
||||||
|
f"Prepare a briefing dossier for task ID: {task_id}\n"
|
||||||
|
"Follow the research workflow: read the task, then project, then client, "
|
||||||
|
"then cross-project relations, then relevant memory. "
|
||||||
|
"End with a concrete suggested first step. "
|
||||||
|
"If this is a writing task, include a <canvas kind=\"...\"> draft."
|
||||||
|
)
|
||||||
|
|
||||||
|
system_prompt, langfuse_prompt = _build_system_prompt(
|
||||||
|
"task_brief_research_system",
|
||||||
|
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT,
|
||||||
|
prepared_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
manifest_block = ""
|
||||||
|
if project_id:
|
||||||
|
manifest = await _fetch_project_manifest(project_id)
|
||||||
|
manifest_block = format_folder_manifest(manifest)
|
||||||
|
system_prompt = system_prompt + ("\n\n" + manifest_block if manifest_block else "")
|
||||||
|
|
||||||
|
async for event in _run_single_agent_stream(
|
||||||
|
user_id=user_id,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
message=research_message,
|
||||||
|
context=prepared_context,
|
||||||
|
max_steps=12,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="task-brief-agent",
|
||||||
|
tools=tools,
|
||||||
|
conversation_history=None,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||||
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
|
|||||||
183
app/core/folder_indexer.py
Normal file
183
app/core/folder_indexer.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Per-file summarisation for project folder integration."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from pypdf import PdfReader
|
||||||
|
from docx import Document as DocxDocument
|
||||||
|
|
||||||
|
from app.core.langfuse_client import (
|
||||||
|
compile_prompt,
|
||||||
|
extract_usage,
|
||||||
|
get_langfuse,
|
||||||
|
get_prompt_or_fallback,
|
||||||
|
)
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
|
||||||
|
_TEXT_FALLBACK = (
|
||||||
|
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
||||||
|
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
||||||
|
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
||||||
|
)
|
||||||
|
_IMAGE_FALLBACK = (
|
||||||
|
"You are summarising an image attached to a project folder.\n"
|
||||||
|
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
||||||
|
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
||||||
|
)
|
||||||
|
_MAX_INPUT_CHARS = 6000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IndexResult:
|
||||||
|
summary: str
|
||||||
|
tokens_used: int
|
||||||
|
|
||||||
|
|
||||||
|
async def _llm_text(messages: list) -> object:
|
||||||
|
"""Make the LLM call for text summarisation.
|
||||||
|
|
||||||
|
Defined as a standalone async function so tests can patch it cleanly
|
||||||
|
without needing to mock the LLM object itself.
|
||||||
|
"""
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||||
|
return await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def _llm_vision(messages: list) -> object:
|
||||||
|
"""Make the LLM call for vision (image) summarisation.
|
||||||
|
|
||||||
|
Accepts the message list and returns the response directly, mirroring
|
||||||
|
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
||||||
|
"""
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||||
|
return await llm.ainvoke(messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
||||||
|
"""Return a compact summary of an image file using vision.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
image_b64:
|
||||||
|
Base64-encoded image bytes.
|
||||||
|
mime:
|
||||||
|
MIME type of the image, e.g. ``"image/png"``.
|
||||||
|
file_name:
|
||||||
|
Optional file name, attached to the Langfuse trace as input metadata.
|
||||||
|
"""
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=template),
|
||||||
|
HumanMessage(content=[
|
||||||
|
{"type": "text", "text": "Summarise this image."},
|
||||||
|
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
||||||
|
]),
|
||||||
|
]
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is not None:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="folder-summarize-image",
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input={"file_name": file_name, "mime": mime},
|
||||||
|
) as gen:
|
||||||
|
response = await _llm_vision(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
gen.update(output=response.content, usage_details=usage)
|
||||||
|
else:
|
||||||
|
response = await _llm_vision(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
summary = (response.content or "").strip()[:500]
|
||||||
|
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
||||||
|
"""Return a compact summary of a text file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
content:
|
||||||
|
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
||||||
|
ext:
|
||||||
|
File extension including the leading dot, e.g. ``".md"``.
|
||||||
|
name:
|
||||||
|
File name, e.g. ``"kickoff.md"``.
|
||||||
|
"""
|
||||||
|
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
||||||
|
truncated = content[:_MAX_INPUT_CHARS]
|
||||||
|
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=compiled),
|
||||||
|
HumanMessage(content="Summarise this file."),
|
||||||
|
]
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is not None:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="folder-summarize-text",
|
||||||
|
model="gpt-4o-mini",
|
||||||
|
prompt=prompt_obj,
|
||||||
|
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
||||||
|
) as gen:
|
||||||
|
response = await _llm_text(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
gen.update(output=response.content, usage_details=usage)
|
||||||
|
else:
|
||||||
|
response = await _llm_text(messages)
|
||||||
|
usage = extract_usage(response)
|
||||||
|
summary = (response.content or "").strip()[:500]
|
||||||
|
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pdf_text(pdf_b64: str) -> str:
|
||||||
|
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
||||||
|
reader = PdfReader(buf)
|
||||||
|
parts: list[str] = []
|
||||||
|
for page in reader.pages:
|
||||||
|
try:
|
||||||
|
parts.append(page.extract_text() or "")
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
return "\n".join(parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_docx_text(docx_b64: str) -> str:
|
||||||
|
buf = io.BytesIO(base64.b64decode(docx_b64))
|
||||||
|
doc = DocxDocument(buf)
|
||||||
|
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
||||||
|
"""Return a compact summary of a PDF file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pdf_b64:
|
||||||
|
Base64-encoded PDF bytes.
|
||||||
|
name:
|
||||||
|
File name, e.g. ``"report.pdf"``.
|
||||||
|
"""
|
||||||
|
text = _extract_pdf_text(pdf_b64)
|
||||||
|
if not text:
|
||||||
|
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||||
|
return await summarize_text(content=text, ext=".pdf", name=name)
|
||||||
|
|
||||||
|
|
||||||
|
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
||||||
|
"""Return a compact summary of a DOCX file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
docx_b64:
|
||||||
|
Base64-encoded DOCX bytes.
|
||||||
|
name:
|
||||||
|
File name, e.g. ``"spec.docx"``.
|
||||||
|
"""
|
||||||
|
text = _extract_docx_text(docx_b64)
|
||||||
|
if not text:
|
||||||
|
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||||
|
return await summarize_text(content=text, ext=".docx", name=name)
|
||||||
@@ -51,6 +51,10 @@ def _api_key_for_model(model: str) -> str | None:
|
|||||||
return settings.GOOGLE_API_KEY or None
|
return settings.GOOGLE_API_KEY or None
|
||||||
if model.startswith("cerebras/"):
|
if model.startswith("cerebras/"):
|
||||||
return settings.CEREBRAS_API_KEY or None
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("groq/"):
|
||||||
|
return settings.GROQ_API_KEY or None
|
||||||
|
if model.startswith("deepseek/"):
|
||||||
|
return settings.DEEPSEEK_API_KEY or None
|
||||||
if model.startswith("github_copilot/"):
|
if model.startswith("github_copilot/"):
|
||||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
# No API key is required; returning None lets LiteLLM handle auth.
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
@@ -102,10 +106,13 @@ _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
|||||||
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||||
|
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
||||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||||
|
"note-summarizer": lambda: "gpt-4o-mini",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
51
app/core/note_summarizer.py
Normal file
51
app/core/note_summarizer.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Note summarizer — generates a compact AI summary for a note.
|
||||||
|
|
||||||
|
Called fire-and-forget from create_note / update_note tools so the
|
||||||
|
``notes.ai_summary`` column stays current without blocking the agent loop.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_prompt_or_fallback
|
||||||
|
from app.core.llm import get_agent_llm
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_FALLBACK_PROMPT = """\
|
||||||
|
Summarize this note in <=250 characters. Be terse and dense.
|
||||||
|
Keep proper nouns, dates, decisions, and action items.
|
||||||
|
Do not start with "This note".
|
||||||
|
Respond with the summary text only — no intro, no labels.
|
||||||
|
|
||||||
|
Title: {title}
|
||||||
|
Content: {content}"""
|
||||||
|
|
||||||
|
_MAX_CONTENT_CHARS = 4000
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_note_summary(title: str, content: str) -> str:
|
||||||
|
"""Return a <=250-char summary of *title* + *content*.
|
||||||
|
|
||||||
|
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
|
||||||
|
fallback. Truncates *content* to 4000 chars before sending to avoid
|
||||||
|
token waste on large notes.
|
||||||
|
"""
|
||||||
|
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
|
||||||
|
trimmed = content[:_MAX_CONTENT_CHARS]
|
||||||
|
system_prompt = template.format(title=title, content=trimmed)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = get_agent_llm("note-summarizer")
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content="Generate the summary."),
|
||||||
|
])
|
||||||
|
text = response.content if isinstance(response.content, str) else ""
|
||||||
|
return text.strip()[:250]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("note_summarizer: failed to generate summary: %s", exc)
|
||||||
|
return ""
|
||||||
@@ -2,11 +2,35 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
|
|
||||||
|
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
||||||
|
_CANVAS_BLOCK_RE = re.compile(
|
||||||
|
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
||||||
|
re.DOTALL | re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
||||||
|
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
||||||
|
|
||||||
|
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
||||||
|
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
||||||
|
"""
|
||||||
|
match = _CANVAS_BLOCK_RE.search(text)
|
||||||
|
if not match:
|
||||||
|
return text, None, None
|
||||||
|
|
||||||
|
canvas_kind = match.group(1).strip()
|
||||||
|
canvas_content = match.group(2).strip()
|
||||||
|
visible = text[: match.start()] + text[match.end() :]
|
||||||
|
visible = visible.strip()
|
||||||
|
return visible, canvas_content, canvas_kind
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,32 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
|
||||||
|
|
||||||
|
|
||||||
|
def _key_to_camel(key: str) -> str:
|
||||||
|
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
|
||||||
|
|
||||||
|
|
||||||
|
def _keys_to_camel(obj: Any) -> Any:
|
||||||
|
"""Recursively convert dict keys from snake_case to camelCase.
|
||||||
|
|
||||||
|
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
|
||||||
|
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
|
||||||
|
tool_result payloads in ``toSnakeCase`` before sending; this restores the
|
||||||
|
camelCase schema property names that the tool code expects to read.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_keys_to_camel(v) for v in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
@@ -82,6 +104,7 @@ async def execute_on_client(
|
|||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
result = await callback(payload)
|
result = await callback(payload)
|
||||||
|
result = _keys_to_camel(result)
|
||||||
collector = _tool_result_collector.get(None)
|
collector = _tool_result_collector.get(None)
|
||||||
if collector is not None:
|
if collector is not None:
|
||||||
collector.append({
|
collector.append({
|
||||||
|
|||||||
@@ -4,6 +4,10 @@ import logging
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
@@ -11,10 +15,6 @@ logging.basicConfig(
|
|||||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
async def _memory_audit_cron_tick() -> None:
|
async def _memory_audit_cron_tick() -> None:
|
||||||
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||||
|
|||||||
@@ -243,6 +243,7 @@ class AgentRunLog(Base):
|
|||||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
started_at: Mapped[datetime] = mapped_column(
|
started_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
@@ -263,6 +264,17 @@ class AgentRunLog(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MonthlyTokenUsage(Base):
|
||||||
|
__tablename__ = "monthly_token_usage"
|
||||||
|
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
|
||||||
|
)
|
||||||
|
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
|
||||||
|
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||||
|
|
||||||
|
|
||||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -85,6 +85,17 @@ class WsFrameType(str, Enum):
|
|||||||
journey_start = "journey_start"
|
journey_start = "journey_start"
|
||||||
journey_message = "journey_message"
|
journey_message = "journey_message"
|
||||||
journey_reply = "journey_reply"
|
journey_reply = "journey_reply"
|
||||||
|
# ── v5 brief frame types ──────────────────────────────────────────
|
||||||
|
brief_request = "brief_request"
|
||||||
|
# ── v6 task brief frame types ─────────────────────────────────────
|
||||||
|
task_brief_request = "task_brief_request"
|
||||||
|
# ── v7 folder index frame types ───────────────────────────────────
|
||||||
|
index_session_start = "index_session_start"
|
||||||
|
index_file_batch = "index_file_batch"
|
||||||
|
index_session_cancel = "index_session_cancel"
|
||||||
|
index_file_result = "index_file_result"
|
||||||
|
index_session_progress = "index_session_progress"
|
||||||
|
index_session_done = "index_session_done"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -140,6 +151,16 @@ class WsDeviceHello(BaseModel):
|
|||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class FormatPrefsModel(BaseModel):
|
||||||
|
"""User display preferences sent by Electron on each request."""
|
||||||
|
|
||||||
|
timezone: str = "UTC"
|
||||||
|
date_format: str = "dd/MM/yyyy"
|
||||||
|
time_format: str = "24h"
|
||||||
|
locale: str = "en-US"
|
||||||
|
now_iso: str = ""
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingScope(BaseModel):
|
class WsFloatingScope(BaseModel):
|
||||||
"""Scope for a floating request — narrows the agent to a specific entity."""
|
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||||
|
|
||||||
@@ -153,6 +174,7 @@ class WsHomeRequest(BaseModel):
|
|||||||
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||||
message: str
|
message: str
|
||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
format_prefs: FormatPrefsModel | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingRequest(BaseModel):
|
class WsFloatingRequest(BaseModel):
|
||||||
@@ -161,6 +183,18 @@ class WsFloatingRequest(BaseModel):
|
|||||||
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||||
message: str
|
message: str
|
||||||
scope: WsFloatingScope
|
scope: WsFloatingScope
|
||||||
|
format_prefs: FormatPrefsModel | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsBriefRequest(BaseModel):
|
||||||
|
"""Client → Server: Request a plain-text brief (home or project)."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.brief_request] = WsFrameType.brief_request
|
||||||
|
request_id: str | None = None
|
||||||
|
session_id: str | None = None
|
||||||
|
mode: Literal["home", "project"]
|
||||||
|
project_id: str | None = None
|
||||||
|
format_prefs: FormatPrefsModel | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsStreamStart(BaseModel):
|
class WsStreamStart(BaseModel):
|
||||||
@@ -183,6 +217,8 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
|
error: str | None = None
|
||||||
|
mutations: list[dict[str, Any]] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsDomain(BaseModel):
|
class WsDomain(BaseModel):
|
||||||
|
|||||||
@@ -33,9 +33,11 @@ google-auth-httplib2>=0.2.0
|
|||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
pgvector>=0.2.5
|
pgvector>=0.2.5
|
||||||
langfuse>=2.0.0
|
langfuse>=3.3.1
|
||||||
beautifulsoup4>=4.12.0
|
beautifulsoup4>=4.12.0
|
||||||
lxml>=5.0.0
|
lxml>=5.0.0
|
||||||
PyYAML>=6.0.0
|
PyYAML>=6.0.0
|
||||||
apscheduler>=3.10.0
|
apscheduler>=3.10.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
pypdf>=4.0
|
||||||
|
python-docx>=1.1
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ from jose import jwt
|
|||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Convenience aliases and per-tier user fixtures ────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db(db_session: AsyncSession) -> AsyncSession:
|
||||||
|
"""Alias for db_session — used by folder quota tests."""
|
||||||
|
return db_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user_free(db_session: AsyncSession):
|
||||||
|
"""Return the seeded free-tier User row."""
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.id == TEST_USER_IDS["free"])
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def test_user_power(db_session: AsyncSession):
|
||||||
|
"""Return the seeded power-tier User row."""
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(User).where(User.id == TEST_USER_IDS["power"])
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def auth_headers_free() -> dict[str, str]:
|
||||||
|
"""Authorization header for the seeded free-tier user."""
|
||||||
|
return auth_header("free")
|
||||||
|
|
||||||
|
|
||||||
# ── CLI options ───────────────────────────────────────────────────────
|
# ── CLI options ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
|
|||||||
@@ -382,7 +382,6 @@ async def test_eval_runner(runner_case, pytestconfig):
|
|||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
_, kwargs = mock_fin.call_args
|
||||||
inserts = [c for c in calls if c["action"] == "insert"]
|
|
||||||
score, comment = _evaluate_case(case, calls, kwargs)
|
score, comment = _evaluate_case(case, calls, kwargs)
|
||||||
|
|
||||||
if obs is not None:
|
if obs is not None:
|
||||||
|
|||||||
163
tests/test_brief_agent.py
Normal file
163
tests/test_brief_agent.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
"""Tests for Phase 3: brief agent WS frame + REST fallback.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
- run_home_brief streams non-empty text (mocked _run_single_agent_stream)
|
||||||
|
- run_project_brief with bogus UUID → WS returns stream_end with error, no crash
|
||||||
|
- _build_read_tools uses read-only subset only (no mutating tools)
|
||||||
|
- POST /chat/brief home mode returns {response: "..."}
|
||||||
|
- POST /chat/brief project mode with invalid UUID → 422
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_USER_ID = TEST_USER_IDS["pro"]
|
||||||
|
_EMPTY_CONTEXT: dict[str, Any] = {"core_memory": {}}
|
||||||
|
|
||||||
|
|
||||||
|
async def _fake_token_stream(*_args, **_kwargs) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Fake _run_single_agent_stream that yields two token events."""
|
||||||
|
yield ("token", "Hello")
|
||||||
|
yield ("token", " world")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit: run_home_brief streams non-empty text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_home_brief_streams_text():
|
||||||
|
with patch(
|
||||||
|
"app.core.brief_agent._run_single_agent_stream",
|
||||||
|
side_effect=_fake_token_stream,
|
||||||
|
):
|
||||||
|
from app.core.brief_agent import run_home_brief
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
async for event_type, data in run_home_brief(_USER_ID, _EMPTY_CONTEXT):
|
||||||
|
if event_type == "token":
|
||||||
|
chunks.append(str(data))
|
||||||
|
|
||||||
|
assert "".join(chunks) == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit: run_project_brief streams text with valid UUID
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_project_brief_streams_text():
|
||||||
|
project_id = str(uuid.uuid4())
|
||||||
|
with patch(
|
||||||
|
"app.core.brief_agent._run_single_agent_stream",
|
||||||
|
side_effect=_fake_token_stream,
|
||||||
|
):
|
||||||
|
from app.core.brief_agent import run_project_brief
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
async for event_type, data in run_project_brief(_USER_ID, project_id, _EMPTY_CONTEXT):
|
||||||
|
if event_type == "token":
|
||||||
|
chunks.append(str(data))
|
||||||
|
|
||||||
|
assert "".join(chunks) == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit: _build_read_tools uses read-only subset (no write tools)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_build_read_tools_read_only_subset():
|
||||||
|
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_READ_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||||
|
from app.core.brief_agent import _build_read_tools
|
||||||
|
|
||||||
|
tools = _build_read_tools(_USER_ID, None)
|
||||||
|
tool_names = {getattr(t, "name", None) or getattr(t, "__name__", str(t)) for t in tools}
|
||||||
|
|
||||||
|
# Read-only exports must be present.
|
||||||
|
for read_list in (TASK_READ_TOOLS, PROJECT_READ_TOOLS, TIMELINE_READ_TOOLS, NOTE_READ_TOOLS):
|
||||||
|
for t in read_list:
|
||||||
|
name = getattr(t, "name", None) or getattr(t, "__name__", str(t))
|
||||||
|
assert name in tool_names, f"Read tool {name!r} missing from _build_read_tools"
|
||||||
|
|
||||||
|
# No mutating tools (e.g. create_task, update_task, delete_task).
|
||||||
|
mutating = {"create_task", "update_task", "delete_task", "create_project",
|
||||||
|
"update_project", "delete_project", "create_note", "update_note",
|
||||||
|
"delete_note", "memory_add", "memory_update", "memory_delete"}
|
||||||
|
overlap = tool_names & mutating
|
||||||
|
assert not overlap, f"Mutating tools in brief read-only subset: {overlap}"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration: POST /chat/brief — home mode
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rest_brief_home_returns_response(client):
|
||||||
|
async def _fake_home_brief(user_id, context):
|
||||||
|
yield ("token", "Today looks light.")
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.routes.chat.run_home_brief", side_effect=_fake_home_brief),
|
||||||
|
patch(
|
||||||
|
"app.api.routes.chat.MemoryMiddleware.enrich_context",
|
||||||
|
new=AsyncMock(return_value={}),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/chat/brief",
|
||||||
|
json={"mode": "home"},
|
||||||
|
headers=auth_header("pro"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
data = res.json()
|
||||||
|
assert data["response"] == "Today looks light."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rest_brief_project_invalid_uuid_returns_422(client):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/chat/brief",
|
||||||
|
json={"mode": "project", "project_id": "not-a-uuid"},
|
||||||
|
headers=auth_header("pro"),
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rest_brief_project_missing_uuid_returns_422(client):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/chat/brief",
|
||||||
|
json={"mode": "project"},
|
||||||
|
headers=auth_header("pro"),
|
||||||
|
)
|
||||||
|
assert res.status_code == 422
|
||||||
@@ -10,8 +10,11 @@ import pytest
|
|||||||
from langchain_core.messages import AIMessage, ToolMessage
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
from app.core.deep_agent import (
|
from app.core.deep_agent import (
|
||||||
|
_build_system_prompt,
|
||||||
|
_datetime_context_injection,
|
||||||
_infer_floating_domain,
|
_infer_floating_domain,
|
||||||
_normalize_tagged_list_lines,
|
_normalize_tagged_list_lines,
|
||||||
|
_request_context_block,
|
||||||
run_floating,
|
run_floating,
|
||||||
run_floating_stream,
|
run_floating_stream,
|
||||||
run_home,
|
run_home,
|
||||||
@@ -91,8 +94,12 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res
|
|||||||
"floating_domain",
|
"floating_domain",
|
||||||
{"type": "timeline", "id": "tl-1", "section": None},
|
{"type": "timeline", "id": "tl-1", "section": None},
|
||||||
)
|
)
|
||||||
assert ("token", "stream-") in events
|
# _run_single_agent_stream uses ainvoke (not astream); the final token is
|
||||||
assert ("token", "ok") in events
|
# the second LLM response which echoes the tool result.
|
||||||
|
token_events = [e for e in events if e[0] == "token"]
|
||||||
|
assert token_events, "Expected at least one token event"
|
||||||
|
combined = "".join(str(e[1]) for e in token_events)
|
||||||
|
assert "Mock Task" in combined
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -286,3 +293,213 @@ async def test_run_floating_stream_returns_fallback_when_sanitization_would_empt
|
|||||||
events.append(event)
|
events.append(event)
|
||||||
|
|
||||||
assert ("token", "No results found.") in events
|
assert ("token", "No results found.") in events
|
||||||
|
|
||||||
|
|
||||||
|
# ── _datetime_context_injection ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fp(tz: str, now_iso: str) -> dict:
|
||||||
|
return {"timezone": tz, "now_iso": now_iso, "date_format": "dd/MM/yyyy", "time_format": "24h"}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ms(block: str, key: str) -> tuple[int, int]:
|
||||||
|
"""Extract [start, end] from a 'key [start, end]' line in the DATE CONTEXT block."""
|
||||||
|
import re
|
||||||
|
m = re.search(rf"^{key}\s+\[(\d+),\s*(\d+)\]", block, re.MULTILINE)
|
||||||
|
assert m, f"Key '{key}' not found in block:\n{block}"
|
||||||
|
return int(m.group(1)), int(m.group(2))
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_context_injection_europe_rome_late_evening():
|
||||||
|
"""22:16 CEST on 2026-04-26 — 'tomorrow' must be 2026-04-27 00:00→23:59:59.999 CEST."""
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z")})
|
||||||
|
assert "DATE CONTEXT" in block
|
||||||
|
assert "Europe/Rome" in block
|
||||||
|
|
||||||
|
tz = ZoneInfo("Europe/Rome")
|
||||||
|
today_start = int(datetime(2026, 4, 26, tzinfo=tz).timestamp() * 1000)
|
||||||
|
today_end = int(datetime(2026, 4, 27, tzinfo=tz).timestamp() * 1000) - 1
|
||||||
|
tomorrow_start = today_end + 1
|
||||||
|
tomorrow_end = int(datetime(2026, 4, 28, tzinfo=tz).timestamp() * 1000) - 1
|
||||||
|
|
||||||
|
t_s, t_e = _parse_ms(block, "today")
|
||||||
|
assert t_s == today_start
|
||||||
|
assert t_e == today_end
|
||||||
|
|
||||||
|
tm_s, tm_e = _parse_ms(block, "tomorrow")
|
||||||
|
assert tm_s == tomorrow_start
|
||||||
|
assert tm_e == tomorrow_end
|
||||||
|
|
||||||
|
# Sanity: window is exactly 86 400 000 ms (1 day, CEST has no DST jump on this date)
|
||||||
|
assert today_end - today_start + 1 == 86_400_000
|
||||||
|
assert tomorrow_end - tomorrow_start + 1 == 86_400_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_context_injection_utc():
|
||||||
|
"""UTC timezone: boundaries are clean UTC midnights."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-01-15T10:00:00Z")})
|
||||||
|
t_s, t_e = _parse_ms(block, "today")
|
||||||
|
expected_start = int(datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
|
assert t_s == expected_start
|
||||||
|
assert t_e == expected_start + 86_400_000 - 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_context_injection_dst_spring_forward():
|
||||||
|
"""Europe/Rome DST spring-forward 2026-03-29: that day is 23h, not 24h."""
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-03-29T08:00:00Z")})
|
||||||
|
tz = ZoneInfo("Europe/Rome")
|
||||||
|
day_start = int(datetime(2026, 3, 29, tzinfo=tz).timestamp() * 1000)
|
||||||
|
day_end = int(datetime(2026, 3, 30, tzinfo=tz).timestamp() * 1000) - 1
|
||||||
|
|
||||||
|
t_s, t_e = _parse_ms(block, "today")
|
||||||
|
assert t_s == day_start
|
||||||
|
assert t_e == day_end
|
||||||
|
assert t_e - t_s + 1 == 23 * 3_600_000 # 23-hour day
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_context_injection_dst_fall_back():
|
||||||
|
"""Europe/Rome DST fall-back 2026-10-25: that day is 25h."""
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-10-25T08:00:00Z")})
|
||||||
|
tz = ZoneInfo("Europe/Rome")
|
||||||
|
day_start = int(datetime(2026, 10, 25, tzinfo=tz).timestamp() * 1000)
|
||||||
|
day_end = int(datetime(2026, 10, 26, tzinfo=tz).timestamp() * 1000) - 1
|
||||||
|
|
||||||
|
t_s, t_e = _parse_ms(block, "today")
|
||||||
|
assert t_s == day_start
|
||||||
|
assert t_e == day_end
|
||||||
|
assert t_e - t_s + 1 == 25 * 3_600_000 # 25-hour day
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_context_injection_year_boundary():
|
||||||
|
"""Dec 31 → Jan 1: last_year, this_year, next_month cross year boundary correctly."""
|
||||||
|
from zoneinfo import ZoneInfo
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-12-31T23:00:00Z")})
|
||||||
|
tz = ZoneInfo("UTC")
|
||||||
|
|
||||||
|
yr_s, yr_e = _parse_ms(block, "this_year")
|
||||||
|
assert yr_s == int(datetime(2026, 1, 1, tzinfo=tz).timestamp() * 1000)
|
||||||
|
assert yr_e == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000) - 1
|
||||||
|
|
||||||
|
ly_s, ly_e = _parse_ms(block, "last_year")
|
||||||
|
assert ly_s == int(datetime(2025, 1, 1, tzinfo=tz).timestamp() * 1000)
|
||||||
|
assert ly_e == yr_s - 1
|
||||||
|
|
||||||
|
nm_s, _ = _parse_ms(block, "next_month")
|
||||||
|
assert nm_s == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_datetime_context_injection_missing_format_prefs():
|
||||||
|
assert _datetime_context_injection({}) == ""
|
||||||
|
assert _datetime_context_injection({"format_prefs": None}) == ""
|
||||||
|
assert _datetime_context_injection({"format_prefs": "bad"}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── _request_context_block ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_request_context_block_scope_and_project():
|
||||||
|
ctx = {"scope": {"type": "task", "id": "t-1"}, "resolved_project_id": "proj-uuid"}
|
||||||
|
block = _request_context_block(ctx)
|
||||||
|
assert "scope" in block
|
||||||
|
assert "resolved_project_id: proj-uuid" in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_request_context_block_empty():
|
||||||
|
assert _request_context_block({}) == ""
|
||||||
|
assert _request_context_block({"scope": None}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── _build_system_prompt ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_build_system_prompt_substitutes_all_slots(monkeypatch):
|
||||||
|
"""All five slots must appear in the compiled output; no raw placeholder remains."""
|
||||||
|
# Patch get_prompt_or_fallback to return None prompt_obj so we use fallback .format() path
|
||||||
|
import app.core.deep_agent as da
|
||||||
|
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
|
||||||
|
|
||||||
|
ctx = {
|
||||||
|
"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z"),
|
||||||
|
"core_memory": {"language": "it"},
|
||||||
|
"relational_memory": ["Alice — client"],
|
||||||
|
"proactive_hints": ["User prefers morning meetings"],
|
||||||
|
"scope": {"type": "task"},
|
||||||
|
"resolved_project_id": "proj-1",
|
||||||
|
}
|
||||||
|
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
|
||||||
|
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, ctx)
|
||||||
|
|
||||||
|
# No unresolved placeholders
|
||||||
|
assert "{date_context}" not in text
|
||||||
|
assert "{language_instruction}" not in text
|
||||||
|
assert "{relational_memory}" not in text
|
||||||
|
assert "{proactive_hints}" not in text
|
||||||
|
assert "{request_context}" not in text
|
||||||
|
|
||||||
|
# Content was injected
|
||||||
|
assert "DATE CONTEXT" in text
|
||||||
|
assert "Italian" in text
|
||||||
|
assert "Alice" in text
|
||||||
|
assert "morning meetings" in text
|
||||||
|
assert "proj-1" in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_system_prompt_empty_format_prefs(monkeypatch):
|
||||||
|
"""Missing format_prefs must not raise — date_context slot renders empty string."""
|
||||||
|
import app.core.deep_agent as da
|
||||||
|
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
|
||||||
|
|
||||||
|
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
|
||||||
|
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, {})
|
||||||
|
# Prompt renders without error; date section is empty but structure holds
|
||||||
|
assert "# Date filtering" in text
|
||||||
|
assert "{date_context}" not in text
|
||||||
|
|
||||||
|
|
||||||
|
def test_human_message_is_bare_message(monkeypatch):
|
||||||
|
"""After the refactor HumanMessage content must equal the raw user message exactly."""
|
||||||
|
import app.core.deep_agent as da
|
||||||
|
from langchain_core.messages import HumanMessage as LCHumanMessage
|
||||||
|
|
||||||
|
captured: list[list] = []
|
||||||
|
|
||||||
|
class _CaptureLLM:
|
||||||
|
def bind_tools(self, _):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, messages):
|
||||||
|
captured.append(list(messages))
|
||||||
|
return AIMessage(content="risposta")
|
||||||
|
|
||||||
|
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda n, f: (f, None))
|
||||||
|
monkeypatch.setattr(da, "get_agent_llm", lambda _: _CaptureLLM())
|
||||||
|
monkeypatch.setattr(da, "_all_tools_for_user", lambda *_: [])
|
||||||
|
monkeypatch.setattr(da, "get_langfuse", lambda: None)
|
||||||
|
monkeypatch.setattr(da, "set_tool_result_collector", lambda _: None)
|
||||||
|
monkeypatch.setattr(da, "clear_tool_result_collector", lambda: None)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
chunks = []
|
||||||
|
ctx = {"format_prefs": _fp("UTC", "2026-04-27T10:00:00Z")}
|
||||||
|
async for ev in da.run_home_stream("u1", "Cosa devo fare domani?", ctx):
|
||||||
|
chunks.append(ev)
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(_run())
|
||||||
|
|
||||||
|
assert captured, "LLM was never called"
|
||||||
|
messages = captured[0]
|
||||||
|
human = next(m for m in messages if isinstance(m, LCHumanMessage))
|
||||||
|
assert human.content == "Cosa devo fare domani?"
|
||||||
|
assert "Context:" not in human.content
|
||||||
|
|||||||
@@ -201,7 +201,6 @@ def test_ws_device_invalid_first_frame_closes(client):
|
|||||||
def test_ws_device_tool_result_dispatched(client):
|
def test_ws_device_tool_result_dispatched(client):
|
||||||
"""tool_result frame is routed to the DeviceConnectionManager."""
|
"""tool_result frame is routed to the DeviceConnectionManager."""
|
||||||
token = make_jwt(tier="free")
|
token = make_jwt(tier="free")
|
||||||
user_id = TEST_USER_IDS["free"]
|
|
||||||
|
|
||||||
from app.core.device_manager import device_manager as dm
|
from app.core.device_manager import device_manager as dm
|
||||||
|
|
||||||
|
|||||||
139
tests/test_folder_agent_tool.py
Normal file
139
tests/test_folder_agent_tool.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agents.folder_agent import (
|
||||||
|
read_project_folder_file,
|
||||||
|
search_project_folder_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_happy_path():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
|
||||||
|
assert "file body" in out
|
||||||
|
assert "kind=text" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_traversal_rejected():
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
|
||||||
|
assert out == "Access denied"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_absolute_path_rejected():
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
|
||||||
|
assert out == "Access denied"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_missing_file():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
|
||||||
|
assert "not found" in out.lower()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pagination_signals_more_available():
|
||||||
|
# Electron returned the first slice, totalSize larger than slice length.
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "big.txt",
|
||||||
|
"offset": 0,
|
||||||
|
"length": 11,
|
||||||
|
})
|
||||||
|
assert "first chunk" in out
|
||||||
|
assert "More content available" in out
|
||||||
|
assert "offset=11" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pdf_extracted_then_sliced(monkeypatch):
|
||||||
|
from app.agents import folder_agent
|
||||||
|
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "doc.pdf",
|
||||||
|
"offset": 0,
|
||||||
|
"length": 8,
|
||||||
|
})
|
||||||
|
assert "kind=pdf" in out
|
||||||
|
assert "ABC ABC " in out
|
||||||
|
assert "More content available" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_image_returns_placeholder():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
|
||||||
|
):
|
||||||
|
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
|
||||||
|
assert "image" in out.lower()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_finds_match_with_context():
|
||||||
|
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
|
||||||
|
):
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "log.txt",
|
||||||
|
"query": "needle",
|
||||||
|
"context_lines": 1,
|
||||||
|
})
|
||||||
|
assert "needle" in out
|
||||||
|
assert "matches=1" in out
|
||||||
|
# Context lines included
|
||||||
|
assert "beta" in out
|
||||||
|
assert "gamma" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_no_match():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
|
||||||
|
):
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "x.txt",
|
||||||
|
"query": "zzz",
|
||||||
|
})
|
||||||
|
assert "No matches" in out
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_rejects_traversal():
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "../etc/passwd",
|
||||||
|
"query": "root",
|
||||||
|
})
|
||||||
|
assert out == "Access denied"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_search_image_rejected():
|
||||||
|
with patch(
|
||||||
|
"app.agents.folder_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
|
||||||
|
):
|
||||||
|
out = await search_project_folder_file.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"relative_path": "logo.png",
|
||||||
|
"query": "anything",
|
||||||
|
})
|
||||||
|
assert "Cannot search" in out
|
||||||
83
tests/test_folder_indexer.py
Normal file
83
tests/test_folder_indexer.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Folder indexer LLM helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_text_returns_summary_and_tokens():
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.content = "Kickoff notes covering scope and deadlines."
|
||||||
|
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||||
|
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
|
||||||
|
assert isinstance(result, IndexResult)
|
||||||
|
assert result.summary == "Kickoff notes covering scope and deadlines."
|
||||||
|
assert result.tokens_used == 338
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_text_truncates_summary_at_500_chars():
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.content = "x" * 1000
|
||||||
|
mock_resp.usage_metadata = {"total_tokens": 100}
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||||
|
result = await summarize_text(content="x", ext=".md", name="x.md")
|
||||||
|
assert len(result.summary) <= 500
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_image_uses_vision_content_blocks():
|
||||||
|
mock_resp = AsyncMock()
|
||||||
|
mock_resp.content = "Final logo on white background."
|
||||||
|
mock_resp.usage_metadata = {"total_tokens": 500}
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
async def fake_llm_vision(messages):
|
||||||
|
captured["messages"] = messages
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
|
||||||
|
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
|
||||||
|
|
||||||
|
assert "Final logo" in result.summary
|
||||||
|
assert result.tokens_used == 500
|
||||||
|
# last message contains an image content block
|
||||||
|
last = captured["messages"][-1]
|
||||||
|
assert any(
|
||||||
|
isinstance(p, dict) and p.get("type") == "image_url"
|
||||||
|
for p in (last.content if isinstance(last.content, list) else [])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
|
||||||
|
# pypdf.PdfReader returns text from pages
|
||||||
|
from app.core import folder_indexer
|
||||||
|
class FakePage:
|
||||||
|
def extract_text(self): return "PDF page content with project info."
|
||||||
|
class FakeReader:
|
||||||
|
pages = [FakePage(), FakePage()]
|
||||||
|
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
|
||||||
|
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
|
||||||
|
async def fake_llm(messages): return mock_resp
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||||
|
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
|
||||||
|
assert "Project info" in result.summary
|
||||||
|
assert result.tokens_used == 50
|
||||||
|
|
||||||
|
|
||||||
|
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
|
||||||
|
from app.core import folder_indexer
|
||||||
|
class FakePara:
|
||||||
|
def __init__(self, t): self.text = t
|
||||||
|
class FakeDoc:
|
||||||
|
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
|
||||||
|
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
|
||||||
|
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
|
||||||
|
async def fake_llm(messages): return mock_resp
|
||||||
|
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||||
|
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
|
||||||
|
assert result.summary == "Heading and body."
|
||||||
94
tests/test_folder_quota.py
Normal file
94
tests/test_folder_quota.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
"""Folder quota helpers."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.billing.quota import (
|
||||||
|
check_folder_quota,
|
||||||
|
add_token_usage,
|
||||||
|
QuotaExceeded,
|
||||||
|
)
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
|
||||||
|
with pytest.raises(QuotaExceeded) as exc:
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
|
||||||
|
)
|
||||||
|
assert exc.value.reason == "max_files"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
|
||||||
|
# No raise
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
db.add(MonthlyTokenUsage(
|
||||||
|
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
|
||||||
|
))
|
||||||
|
await db.commit()
|
||||||
|
with pytest.raises(QuotaExceeded) as exc:
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
|
||||||
|
)
|
||||||
|
assert exc.value.reason == "monthly_tokens"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_folder_quota_power_unlimited(db, test_user_power):
|
||||||
|
await check_folder_quota(
|
||||||
|
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_token_usage_atomic_increment(db, test_user_free):
|
||||||
|
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
|
||||||
|
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
row = (await db.execute(
|
||||||
|
select(MonthlyTokenUsage).where(
|
||||||
|
MonthlyTokenUsage.user_id == test_user_free.id,
|
||||||
|
MonthlyTokenUsage.year_month == ym,
|
||||||
|
MonthlyTokenUsage.feature == "folder_index",
|
||||||
|
)
|
||||||
|
)).scalar_one()
|
||||||
|
assert row.tokens_used == 4000
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
|
||||||
|
result = await add_token_usage(
|
||||||
|
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
|
||||||
|
)
|
||||||
|
assert result.exhausted is True
|
||||||
|
assert result.tokens_used == 150_000
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_check_endpoint_rejects(client, auth_headers_free):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/billing/quota/check",
|
||||||
|
json={"feature": "folder_index", "estimated_files": 500},
|
||||||
|
headers=auth_headers_free,
|
||||||
|
)
|
||||||
|
assert res.status_code == 402
|
||||||
|
body = res.json()
|
||||||
|
assert body["detail"]["reason"] == "max_files"
|
||||||
|
|
||||||
|
|
||||||
|
def test_quota_check_endpoint_passes(client, auth_headers_free):
|
||||||
|
res = client.post(
|
||||||
|
"/api/v1/billing/quota/check",
|
||||||
|
json={"feature": "folder_index", "estimated_files": 50},
|
||||||
|
headers=auth_headers_free,
|
||||||
|
)
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {"ok": True}
|
||||||
@@ -328,7 +328,7 @@ def _make_gmail_message(
|
|||||||
class TestGmailClientFetchMessages:
|
class TestGmailClientFetchMessages:
|
||||||
"""GmailClient.fetch_messages tests with mocked Google API."""
|
"""GmailClient.fetch_messages tests with mocked Google API."""
|
||||||
|
|
||||||
def _make_client(self) -> "GmailClient":
|
def _make_client(self):
|
||||||
from app.integrations.gmail import GmailClient
|
from app.integrations.gmail import GmailClient
|
||||||
return GmailClient(_TOKEN_DICT)
|
return GmailClient(_TOKEN_DICT)
|
||||||
|
|
||||||
@@ -509,7 +509,7 @@ def _make_graph_teams_message(
|
|||||||
class TestMSGraphClientFetchEmails:
|
class TestMSGraphClientFetchEmails:
|
||||||
"""MSGraphClient.fetch_emails tests with mocked httpx."""
|
"""MSGraphClient.fetch_emails tests with mocked httpx."""
|
||||||
|
|
||||||
def _make_client(self) -> "MSGraphClient":
|
def _make_client(self):
|
||||||
from app.integrations.ms_graph import MSGraphClient
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
return MSGraphClient(_MS_TOKEN_DICT)
|
return MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
@@ -608,7 +608,7 @@ class TestMSGraphClientFetchEmails:
|
|||||||
class TestMSGraphClientFetchMessages:
|
class TestMSGraphClientFetchMessages:
|
||||||
"""MSGraphClient.fetch_messages (Teams) tests."""
|
"""MSGraphClient.fetch_messages (Teams) tests."""
|
||||||
|
|
||||||
def _make_client(self) -> "MSGraphClient":
|
def _make_client(self):
|
||||||
from app.integrations.ms_graph import MSGraphClient
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
return MSGraphClient(_MS_TOKEN_DICT)
|
return MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
|||||||
69
tests/test_manifest_injection.py
Normal file
69
tests/test_manifest_injection.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_folder_manifest_basic():
|
||||||
|
manifest = {
|
||||||
|
"folderPath": "D:\\Acme",
|
||||||
|
"lastScannedAt": "2h ago",
|
||||||
|
"files": [
|
||||||
|
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
|
||||||
|
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
out = format_folder_manifest(manifest)
|
||||||
|
assert "<linked_folder>" in out
|
||||||
|
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
|
||||||
|
assert "[text]" in out
|
||||||
|
assert "[image]" in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_folder_manifest_truncates_past_budget():
|
||||||
|
files = [
|
||||||
|
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
|
||||||
|
for i in range(2000)
|
||||||
|
]
|
||||||
|
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
|
||||||
|
assert "more files omitted" in out
|
||||||
|
# Rough token check
|
||||||
|
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_folder_manifest_null_returns_empty():
|
||||||
|
assert format_folder_manifest(None) == ""
|
||||||
|
assert format_folder_manifest({"files": []}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
async def test_brief_multi_project_manifest_top_5_per_project():
|
||||||
|
fake_response = [
|
||||||
|
{
|
||||||
|
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
|
||||||
|
"lastScannedAt": "now",
|
||||||
|
"files": [
|
||||||
|
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
|
||||||
|
for i in range(10)
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
|
||||||
|
"lastScannedAt": "now",
|
||||||
|
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
with patch(
|
||||||
|
"app.core.deep_agent.execute_on_client",
|
||||||
|
new=AsyncMock(return_value={"projects": fake_response}),
|
||||||
|
):
|
||||||
|
from app.core.deep_agent import build_brief_multi_project_manifest
|
||||||
|
out = await build_brief_multi_project_manifest()
|
||||||
|
# Project 1 has 10 files, only top 5 by mtimeMs should appear
|
||||||
|
assert out.count("[p1]") <= 5
|
||||||
|
# Project 2 has 1 file, must appear
|
||||||
|
assert "[p2]" in out or "Beta" in out
|
||||||
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Tests for WS folder index_session handlers (Task 9).
|
||||||
|
|
||||||
|
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||||
|
no real WS connection or LLM call is made.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.api.routes.device_ws import (
|
||||||
|
_handle_index_session_start,
|
||||||
|
_handle_index_file_batch,
|
||||||
|
_handle_index_session_cancel,
|
||||||
|
_index_sessions,
|
||||||
|
)
|
||||||
|
from app.billing.quota import add_token_usage
|
||||||
|
from app.core.folder_indexer import IndexResult
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["free"]
|
||||||
|
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _FakeWebSocket:
|
||||||
|
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[dict] = []
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(json.loads(text))
|
||||||
|
|
||||||
|
def sent_types(self) -> list[str]:
|
||||||
|
return [f["type"] for f in self.sent]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_session_id() -> str:
|
||||||
|
import uuid
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||||
|
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||||
|
async def _fake(**kwargs) -> IndexResult:
|
||||||
|
return IndexResult(summary=summary, tokens_used=tokens)
|
||||||
|
return _fake
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def _clean_sessions():
|
||||||
|
"""Ensure _index_sessions is empty before and after each test."""
|
||||||
|
_index_sessions.clear()
|
||||||
|
yield
|
||||||
|
_index_sessions.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_index_session_happy_path(db_session):
|
||||||
|
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
# Register session.
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"projectId": "proj-1",
|
||||||
|
"totalFiles": 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Verify session was registered.
|
||||||
|
assert session_id in _index_sessions
|
||||||
|
assert _index_sessions[session_id]["total"] == 2
|
||||||
|
assert _index_sessions[session_id]["processed"] == 0
|
||||||
|
# No response frames expected for session_start.
|
||||||
|
assert ws.sent == []
|
||||||
|
|
||||||
|
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||||
|
with patch(
|
||||||
|
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||||
|
# We patch the module-level function in folder_indexer instead:
|
||||||
|
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||||
|
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||||
|
# Wire db_session into the context manager.
|
||||||
|
mock_cm = AsyncMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_async_session.return_value = mock_cm
|
||||||
|
|
||||||
|
await _handle_index_file_batch(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"files": [
|
||||||
|
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||||
|
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||||
|
assert types.count(WsFrameType.index_file_result) == 2
|
||||||
|
assert types.count(WsFrameType.index_session_progress) == 1
|
||||||
|
assert types.count(WsFrameType.index_session_done) == 1
|
||||||
|
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "completed"
|
||||||
|
|
||||||
|
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||||
|
assert progress_frame["processed"] == 2
|
||||||
|
assert progress_frame["total"] == 2
|
||||||
|
|
||||||
|
# Verify session cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def test_index_session_cancel(db_session):
|
||||||
|
"""start then cancel → index_session_done(cancelled)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"totalFiles": 5,
|
||||||
|
})
|
||||||
|
assert session_id in _index_sessions
|
||||||
|
|
||||||
|
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
assert WsFrameType.index_session_done in types
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "cancelled"
|
||||||
|
|
||||||
|
# Session should be cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def test_index_session_quota_exceeded(db_session):
|
||||||
|
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
db_session.add(MonthlyTokenUsage(
|
||||||
|
user_id=USER_ID,
|
||||||
|
year_month=ym,
|
||||||
|
feature="folder_index",
|
||||||
|
tokens_used=100_000, # free tier cap exactly
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"totalFiles": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||||
|
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||||
|
mock_cm = AsyncMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_async_session.return_value = mock_cm
|
||||||
|
|
||||||
|
await _handle_index_file_batch(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"files": [
|
||||||
|
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
# Should have 1 file result (success) then done(quota_exceeded).
|
||||||
|
assert WsFrameType.index_file_result in types
|
||||||
|
assert WsFrameType.index_session_done in types
|
||||||
|
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "quota_exceeded"
|
||||||
|
|
||||||
|
# Session should be cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
Reference in New Issue
Block a user