Compare commits
44 Commits
feature/ba
...
70c19d3064
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70c19d3064 | ||
|
|
886730b47e | ||
|
|
052c7e3741 | ||
|
|
d63fd5f3b9 | ||
|
|
5e42b2abb1 | ||
|
|
2b71469e86 | ||
|
|
6188ae15b3 | ||
|
|
e1db7cdf06 | ||
|
|
c53f08229c | ||
|
|
3e2d80d5bb | ||
|
|
cc0e258e8c | ||
|
|
12e203e63d | ||
|
|
ffcd7390f0 | ||
|
|
91e880f9d4 | ||
|
|
7d47ca54be | ||
|
|
956fa88853 | ||
|
|
fb2f59ccea | ||
|
|
56dbb7f4cd | ||
|
|
506f517851 | ||
|
|
520c186991 | ||
|
|
582bf27deb | ||
|
|
2aeb453229 | ||
|
|
b7a4edac90 | ||
|
|
822b4cd8b1 | ||
|
|
ab24fc4c91 | ||
|
|
a98e99f7a2 | ||
|
|
a0ff285bcd | ||
|
|
177c1a87dd | ||
|
|
441a4ea05c | ||
|
|
a693a64bf5 | ||
|
|
67562b8092 | ||
|
|
6f4c68b359 | ||
|
|
c20c6d7853 | ||
|
|
6787e690ba | ||
|
|
cb8f56d909 | ||
|
|
2c7cac9e03 | ||
|
|
ea9094f47f | ||
|
|
d5fea95561 | ||
|
|
0b5ef48463 | ||
|
|
ca8721e1ac | ||
|
|
f658e5e6a3 | ||
|
|
341ee140e5 | ||
|
|
741b9b87fb | ||
|
|
2d8abb6311 |
25
.env.example
25
.env.example
@@ -21,6 +21,8 @@ OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GOOGLE_API_KEY=
|
||||
CEREBRAS_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
DEEPSEEK_API_KEY=
|
||||
|
||||
# Default model used by any agent that does not have a specific override below.
|
||||
LLM_MODEL=gpt-5-mini
|
||||
@@ -50,9 +52,32 @@ LLM_MODEL_UNIFIED_PROCESSOR=
|
||||
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||
LLM_MODEL_CLOUD_PROCESSOR=
|
||||
|
||||
# Brief-agent — produces home and project text briefs.
|
||||
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||
# LLM_MODEL_BRIEF_AGENT=
|
||||
|
||||
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
||||
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
||||
# LLM_MODEL_TASK_BRIEF_AGENT=
|
||||
|
||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||
LLM_MODEL_SETUP_AGENT=
|
||||
|
||||
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
||||
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||
|
||||
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
||||
# Defaults to gpt-4o-mini when empty.
|
||||
LLM_MODEL_MEMORY_MINER=
|
||||
|
||||
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
||||
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
||||
LLM_MODEL_MEMORY_AUDITOR=
|
||||
|
||||
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||
SCHEDULER_ENABLED=true
|
||||
|
||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -28,6 +28,9 @@ tests/fixtures/private*/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
# Smoke scripts (dev-only, not for CI)
|
||||
scripts/smoke_*.py
|
||||
Thumbs.db
|
||||
|
||||
# Claude Code
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
## DEV
|
||||
Run in DEV with command:
|
||||
```
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
|
||||
```
|
||||
54
alembic/versions/005_associative_pgvector.py
Normal file
54
alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||
|
||||
Migration 004 created the embedding column as vector(1536) and added the
|
||||
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||
1. Ensures the pgvector extension is enabled (idempotent).
|
||||
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||
memory_associative_embedding_idx (creates it only if absent).
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 9a1f2d0b6c7e
|
||||
Create Date: 2026-04-15
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "005"
|
||||
down_revision: Union[str, None] = "e04100e88ace"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
|
||||
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE tablename = 'memory_associative'
|
||||
AND indexname = 'memory_associative_embedding_idx'
|
||||
) THEN
|
||||
CREATE INDEX memory_associative_embedding_idx
|
||||
ON memory_associative
|
||||
USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 100);
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||
74
alembic/versions/006_memory_relations.py
Normal file
74
alembic/versions/006_memory_relations.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""Add memory_relations table (Phase 3 — relational tier).
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 1f5975a4f3f4
|
||||
Create Date: 2026-04-16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "006"
|
||||
down_revision: Union[str, None] = "1f5975a4f3f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"memory_relations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("subject_label", sa.String(128), nullable=False),
|
||||
sa.Column("subject_type", sa.String(32), nullable=False),
|
||||
sa.Column("predicate", sa.String(64), nullable=False),
|
||||
sa.Column("object_label", sa.String(128), nullable=False),
|
||||
sa.Column("object_type", sa.String(32), nullable=False),
|
||||
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
|
||||
sa.Column(
|
||||
"source_episode_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"memory_relations_user_subject_idx",
|
||||
"memory_relations",
|
||||
["user_id", "subject_label"],
|
||||
)
|
||||
op.create_index(
|
||||
"memory_relations_user_predicate_idx",
|
||||
"memory_relations",
|
||||
["user_id", "predicate"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
|
||||
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
|
||||
op.drop_table("memory_relations")
|
||||
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
38
alembic/versions/1f5975a4f3f4_add_extraction_queue.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""add extraction_queue
|
||||
|
||||
Revision ID: 1f5975a4f3f4
|
||||
Revises: 005
|
||||
Create Date: 2026-04-16 17:26:25.790870
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1f5975a4f3f4'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'extraction_queue',
|
||||
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
|
||||
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
|
||||
op.drop_table('extraction_queue')
|
||||
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
46
alembic/versions/d6e3f4a5b6c7_folder_index_tables.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Add token tracking columns for folder integration.
|
||||
|
||||
Revision ID: d6e3f4a5b6c7
|
||||
Revises: 006
|
||||
Create Date: 2026-05-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d6e3f4a5b6c7"
|
||||
down_revision: Union[str, None] = "006"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"agent_run_logs",
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.create_table(
|
||||
"monthly_token_usage",
|
||||
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("year_month", sa.String(7), nullable=False),
|
||||
sa.Column("feature", sa.String(64), nullable=False),
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_monthly_token_usage_user_month",
|
||||
"monthly_token_usage",
|
||||
["user_id", "year_month"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
||||
op.drop_table("monthly_token_usage")
|
||||
op.drop_column("agent_run_logs", "tokens_used")
|
||||
52
app/agents/client_agent.py
Normal file
52
app/agents/client_agent.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Client agent — read-only tools for the clients table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_clients(search: str = "", limit: int = 20) -> str:
|
||||
"""List clients, optionally filtered by a name/email substring search.
|
||||
|
||||
search: optional substring to match against client name or email.
|
||||
limit: max rows to return (default 20).
|
||||
"""
|
||||
filters: dict[str, Any] = {"limit": limit}
|
||||
if search:
|
||||
filters["search"] = search
|
||||
|
||||
result = await execute_on_client(action="select", table="clients", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No clients found."
|
||||
lines = [
|
||||
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
||||
f"company: {r.get('company', '')})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_client(id: str) -> str:
|
||||
"""Get full details for one client by UUID.
|
||||
|
||||
id: the client's UUID.
|
||||
"""
|
||||
if not id:
|
||||
return "Client id is required."
|
||||
|
||||
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
||||
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
||||
if not row:
|
||||
return f"Client '{id}' not found."
|
||||
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
||||
|
||||
|
||||
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
||||
168
app/agents/folder_agent.py
Normal file
168
app/agents/folder_agent.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Scoped file-read and search tools for the project folder feature."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# Cap returned slice size to keep tool output under control.
|
||||
_MAX_RETURN_CHARS = 50_000
|
||||
_MAX_SEARCH_MATCHES = 20
|
||||
|
||||
|
||||
def _is_unsafe_path(rel: str) -> bool:
|
||||
if not rel:
|
||||
return True
|
||||
norm = rel.replace("\\", "/")
|
||||
if norm.startswith("/"):
|
||||
return True
|
||||
# Windows drive letter
|
||||
if len(rel) >= 2 and rel[1] == ":":
|
||||
return True
|
||||
parts = norm.split("/")
|
||||
return ".." in parts
|
||||
|
||||
|
||||
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
||||
"""Return the raw Electron tool_result dict for a file read."""
|
||||
return await execute_on_client(
|
||||
action="read_project_folder_file",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"relativePath": relative_path,
|
||||
"offset": offset,
|
||||
"length": length,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _decode(result: dict) -> tuple[str, str, int]:
|
||||
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
||||
extracts text from base64. For images, returns a placeholder string.
|
||||
For text, content is already a sliced utf-8 string.
|
||||
"""
|
||||
kind = result.get("kind", "text")
|
||||
content = result.get("content", "") or ""
|
||||
total = int(result.get("totalSize", 0) or 0)
|
||||
if kind == "image":
|
||||
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
||||
if kind == "pdf":
|
||||
return (_extract_pdf_text(content), kind, total)
|
||||
if kind == "docx":
|
||||
return (_extract_docx_text(content), kind, total)
|
||||
return (content, kind, total)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
offset: int = 0,
|
||||
length: int = _MAX_RETURN_CHARS,
|
||||
) -> str:
|
||||
"""Read a slice of a file inside the project's linked folder.
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
offset: char offset to start reading from (0 = beginning).
|
||||
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
||||
|
||||
Returns text content slice with a header showing position. Header tells you
|
||||
when more content is available; call again with the suggested next offset.
|
||||
|
||||
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
||||
on the extracted text. For images returns a placeholder; navigate with the
|
||||
manifest summary instead.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
|
||||
result = await _fetch_file(project_id, relative_path, offset, length)
|
||||
text, kind, total_size = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
|
||||
if kind in ("pdf", "docx"):
|
||||
# Backend extracted full text — apply offset/length on chars.
|
||||
sliced = text[offset:offset + length]
|
||||
slice_end = min(offset + length, len(text))
|
||||
header = (
|
||||
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
||||
f"totalChars={len(text)}]"
|
||||
)
|
||||
if slice_end < len(text):
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + sliced
|
||||
|
||||
if kind == "text":
|
||||
slice_end = offset + len(text)
|
||||
header = (
|
||||
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
||||
f"totalBytes={total_size}]"
|
||||
)
|
||||
if slice_end < total_size:
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + text
|
||||
|
||||
# image or unknown
|
||||
return text
|
||||
|
||||
|
||||
@tool
|
||||
async def search_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
query: str,
|
||||
context_lines: int = 3,
|
||||
) -> str:
|
||||
"""Search a project folder file for a query string (case-insensitive substring).
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
query: text to search for.
|
||||
context_lines: number of lines of context around each match (default 3).
|
||||
|
||||
Returns matching line ranges with surrounding context and 1-based line numbers.
|
||||
Capped at 20 matches; if more exist the header shows the total.
|
||||
|
||||
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
||||
Images and binary files are not searchable.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
if not query:
|
||||
return "Empty query."
|
||||
|
||||
# For text we still need full file; pass length=very large.
|
||||
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
||||
text, kind, _ = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
if kind == "image":
|
||||
return "Cannot search inside images."
|
||||
|
||||
lines = text.splitlines()
|
||||
q = query.lower()
|
||||
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
||||
if not matches:
|
||||
return f"No matches for '{query}' in {relative_path}."
|
||||
|
||||
shown = matches[:_MAX_SEARCH_MATCHES]
|
||||
snippets: list[str] = []
|
||||
for i in shown:
|
||||
start = max(0, i - context_lines)
|
||||
end = min(len(lines), i + context_lines + 1)
|
||||
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
||||
snippets.append(block)
|
||||
|
||||
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
||||
body = "\n---\n".join(snippets)
|
||||
return header + "\n" + body
|
||||
|
||||
|
||||
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
||||
@@ -1,13 +1,14 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.llm import embed
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
@@ -19,9 +20,21 @@ def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
def _fmt_summary(row: dict) -> str:
|
||||
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
|
||||
if summary:
|
||||
return f" — {summary}"
|
||||
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
|
||||
return f" — {snippet}" if snippet else ""
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes, optionally scoped to a project by project_id."""
|
||||
"""List notes with AI summaries, optionally scoped to a project by project_id.
|
||||
|
||||
Returns id, title, and ai_summary for each note so you can decide which
|
||||
note to read in full with get_note before creating or updating.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
@@ -31,7 +44,7 @@ async def list_notes(project_id: str = "") -> str:
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No notes found."
|
||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
|
||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@@ -66,14 +79,10 @@ async def create_note(
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
# Index the note content in the vector store.
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||
note_id: str = row["id"]
|
||||
# Generate summary asynchronously — fire-and-forget.
|
||||
asyncio.create_task(_refresh_summary(note_id, title, content))
|
||||
return f"Note created: '{row['title']}' (id: {note_id})."
|
||||
|
||||
|
||||
@tool
|
||||
@@ -82,7 +91,8 @@ async def update_note(
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
) -> str:
|
||||
"""Update an existing note. Only pass fields that should change.
|
||||
"""Update an existing note directly (no approval required).
|
||||
Use propose_note_edit instead when human review is needed.
|
||||
note_id: UUID of the note (required)
|
||||
If you need to preserve existing content, call get_note first.
|
||||
"""
|
||||
@@ -97,17 +107,63 @@ async def update_note(
|
||||
data={"id": note_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
# Re-index if content changed.
|
||||
if content:
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
new_title = title or row.get("title", "")
|
||||
asyncio.create_task(_refresh_summary(note_id, new_title, content))
|
||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def propose_note_edit(
|
||||
note_id: str,
|
||||
edit_type: str,
|
||||
proposed_content: str,
|
||||
reasoning: str = "",
|
||||
anchor_before: str = "",
|
||||
anchor_text: str = "",
|
||||
agent_id: str = "",
|
||||
run_id: str = "",
|
||||
) -> str:
|
||||
"""Propose an AI edit to an existing note, pending human approval.
|
||||
|
||||
Use this instead of update_note when review_required is true.
|
||||
The user will see the proposal highlighted before it is merged.
|
||||
|
||||
note_id: UUID of the target note (required)
|
||||
edit_type: 'append' | 'insert' | 'replace'
|
||||
- append: adds proposed_content at the end of the note
|
||||
- insert: inserts proposed_content immediately after anchor_before text
|
||||
- replace: replaces the first occurrence of anchor_text with proposed_content
|
||||
proposed_content: the new Markdown text to add or substitute (required)
|
||||
reasoning: brief explanation shown to the user (recommended)
|
||||
anchor_before: for 'insert' — the text snippet that precedes the insertion point
|
||||
anchor_text: for 'replace' — the exact text to be replaced
|
||||
agent_id: agent identifier (for traceability)
|
||||
run_id: run identifier (for traceability)
|
||||
"""
|
||||
if edit_type not in ("append", "insert", "replace"):
|
||||
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
|
||||
|
||||
result = await execute_on_client(
|
||||
action="propose_note_edit",
|
||||
data={
|
||||
"noteId": note_id,
|
||||
"type": edit_type,
|
||||
"proposedContent": proposed_content,
|
||||
"reasoning": reasoning or None,
|
||||
"anchorBefore": anchor_before or None,
|
||||
"anchorText": anchor_text or None,
|
||||
"agentId": agent_id or None,
|
||||
"runId": run_id or None,
|
||||
},
|
||||
)
|
||||
edit_id = result.get("id", "?")
|
||||
return (
|
||||
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
|
||||
f"Status: pending user approval."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
@@ -115,10 +171,36 @@ async def delete_note(note_id: str) -> str:
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
|
||||
"""Generate and persist the AI summary for a note. Fire-and-forget."""
|
||||
try:
|
||||
summary = await generate_note_summary(title, content)
|
||||
if summary:
|
||||
await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={
|
||||
"id": note_id,
|
||||
"updates": {
|
||||
"aiSummary": summary,
|
||||
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass # fire-and-forget; errors logged by generate_note_summary
|
||||
|
||||
|
||||
NOTE_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
create_note,
|
||||
update_note,
|
||||
propose_note_edit,
|
||||
delete_note,
|
||||
]
|
||||
|
||||
NOTE_READ_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
]
|
||||
|
||||
@@ -125,3 +125,9 @@ PROJECT_TOOLS: list[Any] = [
|
||||
update_project,
|
||||
delete_project,
|
||||
]
|
||||
|
||||
PROJECT_READ_TOOLS: list[Any] = [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
]
|
||||
|
||||
63
app/agents/relations_agent.py
Normal file
63
app/agents/relations_agent.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
|
||||
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
||||
# Each tool closure captures the user_id bound at factory time.
|
||||
|
||||
|
||||
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
||||
"""Return a query_relations tool bound to *user_id*."""
|
||||
|
||||
@tool
|
||||
async def query_relations(
|
||||
subject_label: str = "",
|
||||
predicate: str = "",
|
||||
object_label: str = "",
|
||||
limit: int = 10,
|
||||
) -> str:
|
||||
"""Query the relational memory graph for entity relationships.
|
||||
|
||||
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
||||
All parameters are optional — omit to retrieve all relations up to limit.
|
||||
|
||||
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
||||
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
||||
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
||||
limit: max rows to return (default 10).
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(
|
||||
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
||||
trace_id or "-", user_id, subject_label, predicate, object_label,
|
||||
)
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
rows = await memory.query_relations(
|
||||
user_id=user_id,
|
||||
subject=subject_label or None,
|
||||
predicate=predicate or None,
|
||||
object_=object_label or None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return "No relational memory entries found for the given filters."
|
||||
|
||||
lines = [
|
||||
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
||||
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
||||
|
||||
return query_relations
|
||||
@@ -26,32 +26,141 @@ def _is_uuid(value: str) -> bool:
|
||||
async def list_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignee: str = "",
|
||||
search: str = "",
|
||||
order_by: str = "",
|
||||
order_dir: str = "",
|
||||
due_date_from: int = -1,
|
||||
due_date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||
"""List tasks with optional filters. Returns up to `limit` results (default 50).
|
||||
|
||||
project_id: UUID of the project to scope results to.
|
||||
status: filter by status — todo | in_progress | done.
|
||||
priority: filter by priority — high | medium | low.
|
||||
assignee: substring to match against assignee names. OMIT unless the user explicitly
|
||||
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
|
||||
Do NOT default to the current user.
|
||||
search: substring search across title and description.
|
||||
order_by: sort field — dueDate | priority | createdAt | completedAt.
|
||||
order_dir: asc (default) | desc.
|
||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||
limit: max rows to return (default 50). Use with offset to paginate.
|
||||
offset: skip first N rows (default 0).
|
||||
|
||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
|
||||
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
|
||||
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
},
|
||||
)
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"priority": priority or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
"orderDir": order_dir or None,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if assignee:
|
||||
filters["assignee"] = assignee
|
||||
if due_date_from != -1:
|
||||
filters["dueDateFrom"] = due_date_from
|
||||
if due_date_to != -1:
|
||||
filters["dueDateTo"] = due_date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
|
||||
result = await execute_on_client(action="select", table="tasks", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
|
||||
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def count_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignee: str = "",
|
||||
search: str = "",
|
||||
due_date_from: int = -1,
|
||||
due_date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
) -> str:
|
||||
"""Count tasks matching the given filters without returning rows.
|
||||
|
||||
Use this instead of list_tasks for "how many" questions — it is much cheaper.
|
||||
Same filter parameters as list_tasks (no limit/offset/order_by needed).
|
||||
assignee: OMIT unless the user explicitly names a person or refers to themselves
|
||||
("my tasks"). Do NOT default to the current user.
|
||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"priority": priority or None,
|
||||
"search": search or None,
|
||||
}
|
||||
if assignee:
|
||||
filters["assignee"] = assignee
|
||||
if due_date_from != -1:
|
||||
filters["dueDateFrom"] = due_date_from
|
||||
if due_date_to != -1:
|
||||
filters["dueDateTo"] = due_date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
|
||||
result = await execute_on_client(action="count", table="tasks", filters=filters)
|
||||
return f"Task count: {result.get('count', 0)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_task(
|
||||
title: str,
|
||||
@@ -72,6 +181,8 @@ async def create_task(
|
||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||
project_id: optional UUID of the parent project
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
|
||||
completedAt is set automatically when status is 'done'.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
@@ -90,7 +201,7 @@ async def create_task(
|
||||
row = result["row"]
|
||||
return (
|
||||
f"Task created: '{row['title']}' "
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
|
||||
)
|
||||
|
||||
|
||||
@@ -108,6 +219,10 @@ async def update_task(
|
||||
"""Update fields on an existing task. Only pass fields you want to change.
|
||||
task_id: the task's UUID (required)
|
||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||
|
||||
completedAt is managed automatically:
|
||||
- setting status to 'done' records the current timestamp
|
||||
- changing status away from 'done' clears completedAt
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
@@ -130,7 +245,7 @@ async def update_task(
|
||||
data={"id": task_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -141,21 +256,36 @@ async def delete_task(task_id: str) -> str:
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today() -> str:
|
||||
"""List all tasks whose due date falls on today's date."""
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||
async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
|
||||
"""List all tasks whose due date falls on today's date.
|
||||
|
||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||
Always pass the user's timezone so 'today' is computed in their local time.
|
||||
include_done: set True to also include already-completed tasks due today (default False).
|
||||
"""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
tz = ZoneInfo(user_timezone or "UTC")
|
||||
except Exception:
|
||||
tz = timezone.utc
|
||||
now_local = datetime.now(tz=tz)
|
||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1
|
||||
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
|
||||
if not include_done:
|
||||
filters["status"] = "todo"
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||
filters=filters,
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks are due today."
|
||||
lines = [
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||
@@ -193,7 +323,6 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
)
|
||||
row = result.get("row", {})
|
||||
row_author = row.get("author", author)
|
||||
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||
row_comment_id = row.get("id", "unknown")
|
||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||
@@ -211,6 +340,7 @@ async def delete_task_comment(comment_id: str) -> str:
|
||||
|
||||
TASK_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
count_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
@@ -219,3 +349,10 @@ TASK_TOOLS: list[Any] = [
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
|
||||
TASK_READ_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
count_tasks,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
]
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
@@ -19,19 +20,128 @@ def _is_uuid(value: str) -> bool:
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines(project_id: str = "") -> str:
|
||||
"""List timelines. Provide project_id to scope to a specific project."""
|
||||
async def list_timelines(
|
||||
project_id: str = "",
|
||||
type: str = "",
|
||||
is_completed: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
order_by: str = "",
|
||||
order_dir: str = "",
|
||||
date_from: int = -1,
|
||||
date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""List timeline events (milestones, checkpoints, activities) with optional filters.
|
||||
|
||||
project_id: UUID to scope results to a specific project.
|
||||
type: filter by event type — milestone | checkpoint | activity.
|
||||
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
|
||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||
order_by: sort field — date (default) | createdAt | completedAt.
|
||||
order_dir: asc (default) | desc.
|
||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
limit: max rows to return (default 50). Use with offset to paginate.
|
||||
offset: skip first N rows (default 0).
|
||||
|
||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
|
||||
Tip — for natural-language windows ("today", "this week", "last month", etc.)
|
||||
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="timelines",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"orderBy": order_by or None,
|
||||
"orderDir": order_dir or None,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if type:
|
||||
filters["type"] = type
|
||||
if is_completed != -1:
|
||||
filters["isCompleted"] = is_completed
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
if date_from != -1:
|
||||
filters["dateFrom"] = date_from
|
||||
if date_to != -1:
|
||||
filters["dateTo"] = date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
|
||||
result = await execute_on_client(action="select", table="timelines", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timelines found."
|
||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||
return "No timeline events found."
|
||||
lines = [
|
||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def count_timelines(
|
||||
project_id: str = "",
|
||||
type: str = "",
|
||||
is_completed: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
date_from: int = -1,
|
||||
date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
) -> str:
|
||||
"""Count timeline events matching the given filters without returning rows.
|
||||
|
||||
Use this instead of list_timelines for "how many" questions — it is much cheaper.
|
||||
Same filter parameters as list_timelines (no limit/offset/order_by needed).
|
||||
|
||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {"projectId": normalized_project_id or None}
|
||||
if type:
|
||||
filters["type"] = type
|
||||
if is_completed != -1:
|
||||
filters["isCompleted"] = is_completed
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
if date_from != -1:
|
||||
filters["dateFrom"] = date_from
|
||||
if date_to != -1:
|
||||
filters["dateTo"] = date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
|
||||
result = await execute_on_client(action="count", table="timelines", filters=filters)
|
||||
return f"Timeline event count: {result.get('count', 0)}"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -39,13 +149,19 @@ async def create_timeline(
|
||||
project_id: str,
|
||||
title: str,
|
||||
date: int,
|
||||
type: str = "milestone",
|
||||
is_completed: int = 0,
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a project timeline (milestone).
|
||||
"""Create a project timeline event.
|
||||
project_id: REQUIRED UUID of the parent project
|
||||
title: descriptive name for the milestone
|
||||
date: Unix timestamp in milliseconds
|
||||
title: descriptive name for the event
|
||||
date: Unix timestamp in milliseconds for the event date
|
||||
type: milestone (default) | checkpoint | activity
|
||||
is_completed: 1 if already completed, 0 if not (default 0)
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
|
||||
completedAt is set automatically when is_completed is 1.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
@@ -54,11 +170,13 @@ async def create_timeline(
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"type": type,
|
||||
"isCompleted": is_completed,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||
return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
|
||||
|
||||
|
||||
@tool
|
||||
@@ -66,35 +184,87 @@ async def update_timeline(
|
||||
timeline_id: str,
|
||||
title: str = "",
|
||||
date: int = -1,
|
||||
is_completed: int = -1,
|
||||
) -> str:
|
||||
"""Update a timeline. Only pass fields that should change.
|
||||
timeline_id: UUID of the timeline (required)
|
||||
"""Update a timeline event. Only pass fields that should change.
|
||||
timeline_id: UUID of the event (required)
|
||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
|
||||
|
||||
completedAt is managed automatically:
|
||||
- setting is_completed to 1 records the current timestamp
|
||||
- setting is_completed to 0 clears completedAt
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if date != -1:
|
||||
updates["date"] = date
|
||||
if is_completed != -1:
|
||||
updates["isCompleted"] = is_completed
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="timelines",
|
||||
data={"id": timeline_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||
return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_timeline(timeline_id: str) -> str:
|
||||
"""Delete a timeline permanently by its UUID."""
|
||||
"""Delete a timeline event permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||
return f"Timeline {timeline_id} deleted."
|
||||
return f"Timeline event {timeline_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
|
||||
"""List all timeline events whose date falls on today.
|
||||
|
||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||
Always pass the user's timezone so 'today' is computed in their local time.
|
||||
include_completed: set False to exclude already-completed events (default True).
|
||||
"""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
tz = ZoneInfo(user_timezone or "UTC")
|
||||
except Exception:
|
||||
tz = timezone.utc
|
||||
now_local = datetime.now(tz=tz)
|
||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1
|
||||
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
|
||||
if not include_completed:
|
||||
filters["isCompleted"] = 0
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="timelines",
|
||||
filters=filters,
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timeline events today."
|
||||
lines = [
|
||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
TIMELINE_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
count_timelines,
|
||||
list_timelines_today,
|
||||
create_timeline,
|
||||
update_timeline,
|
||||
delete_timeline,
|
||||
]
|
||||
|
||||
TIMELINE_READ_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
count_timelines,
|
||||
list_timelines_today,
|
||||
]
|
||||
|
||||
@@ -16,16 +16,17 @@ import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.db import get_session
|
||||
from app.models import AgentRunLog, LocalAgentConfig
|
||||
from app.schemas import (
|
||||
@@ -37,6 +38,8 @@ from app.schemas import (
|
||||
UserProfile,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
|
||||
@@ -230,3 +233,25 @@ async def trigger_agent_run(
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
|
||||
|
||||
# ── Note summary endpoint ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class NoteSummarizeRequest(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
class NoteSummarizeResponse(BaseModel):
|
||||
summary: str
|
||||
|
||||
|
||||
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
|
||||
async def summarize_note(
|
||||
body: NoteSummarizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> NoteSummarizeResponse:
|
||||
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
||||
summary = await generate_note_summary(body.title, body.content)
|
||||
return NoteSummarizeResponse(summary=summary)
|
||||
|
||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request, status
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -96,3 +96,37 @@ async def list_invoices(
|
||||
"""
|
||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||
return invoices
|
||||
|
||||
|
||||
# ── Quota check ────────────────────────────────────────────────────────
|
||||
|
||||
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
||||
|
||||
|
||||
class QuotaCheckRequest(BaseModel):
|
||||
feature: str
|
||||
estimated_files: int
|
||||
|
||||
|
||||
@router.post("/quota/check")
|
||||
async def quota_check(
|
||||
payload: QuotaCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
||||
if payload.feature != "folder_index":
|
||||
raise HTTPException(status_code=400, detail="Unknown feature")
|
||||
try:
|
||||
await check_folder_quota(
|
||||
user_id=current_user.id,
|
||||
tier=current_user.tier,
|
||||
estimated_files=payload.estimated_files,
|
||||
db=db,
|
||||
)
|
||||
except QuotaExceeded as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={"reason": exc.reason, "message": str(exc)},
|
||||
)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -5,13 +5,19 @@ WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_home
|
||||
from app.core.llm import embed
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
from app.schemas import ChatRequest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
@@ -45,6 +51,57 @@ async def chat(
|
||||
return JSONResponse(content={"response": response})
|
||||
|
||||
|
||||
class _BriefRequest(BaseModel):
|
||||
mode: Literal["home", "project"]
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class _BriefResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
@router.post("/brief", response_model=_BriefResponse)
|
||||
async def brief(
|
||||
body: _BriefRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _BriefResponse:
|
||||
"""REST fallback for brief when the device WebSocket is not ready."""
|
||||
if body.mode == "project":
|
||||
if not body.project_id:
|
||||
raise HTTPException(status_code=422, detail="project_id required for project mode")
|
||||
try:
|
||||
uuid.UUID(body.project_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
current_user.id,
|
||||
"",
|
||||
trace_id=request_id,
|
||||
session_id=request_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "user_id": current_user.id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
chunks: list[str] = []
|
||||
if body.mode == "project":
|
||||
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
|
||||
else:
|
||||
stream = run_home_brief(current_user.id, context)
|
||||
|
||||
async for event_type, data in stream:
|
||||
if event_type == "token" and data:
|
||||
chunks.append(str(data))
|
||||
|
||||
return _BriefResponse(response="".join(chunks))
|
||||
|
||||
|
||||
@router.post("/embed", response_model=_EmbedResponse)
|
||||
async def embed_text(
|
||||
body: _EmbedRequest,
|
||||
|
||||
@@ -42,19 +42,27 @@ from sqlalchemy import update
|
||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||
from app.core.agent_session_buffer import session_buffer
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
|
||||
from app.core.output_formatter import extract_canvas_block
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.output_formatter import StreamFormatter
|
||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
from app.schemas import WsFrameType
|
||||
from app.schemas import WsFrameType, WsStreamEnd
|
||||
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
# ── v7 folder index session state ─────────────────────────────────────
|
||||
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||
_index_sessions: dict[str, dict] = {}
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
@@ -153,9 +161,14 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_home_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.floating_request:
|
||||
elif frame_type == WsFrameType.brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_floating_request(websocket, user_id, frame)
|
||||
_handle_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.task_brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_task_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_start:
|
||||
@@ -168,6 +181,29 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_journey_message(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_start:
|
||||
asyncio.create_task(
|
||||
_handle_index_session_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_file_batch:
|
||||
asyncio.create_task(
|
||||
_handle_index_file_batch(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_cancel:
|
||||
await _handle_index_session_cancel(websocket, frame)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_request:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_scope_update:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
@@ -199,11 +235,13 @@ async def _handle_home_request(
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
logger.info(
|
||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
project_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
@@ -220,6 +258,7 @@ async def _handle_home_request(
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
@@ -227,7 +266,7 @@ async def _handle_home_request(
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_home_stream(user_id, message, context)
|
||||
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
@@ -257,26 +296,41 @@ async def _handle_home_request(
|
||||
)
|
||||
|
||||
|
||||
async def _handle_floating_request(
|
||||
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
|
||||
|
||||
|
||||
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
|
||||
"""Return a session-scoped buffer proxy for the given user+session.
|
||||
|
||||
Returns a _ContextualBufferProxy that exposes append_system_message().
|
||||
Defined at module level so tests can monkeypatch it.
|
||||
The channel kwarg is accepted for forward-compatibility.
|
||||
"""
|
||||
from app.core.agent_session_buffer import ContextualBufferProxy # noqa: PLC0415
|
||||
return ContextualBufferProxy(session_buffer, user_id, session_id)
|
||||
|
||||
|
||||
async def _handle_contextual_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
||||
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope: dict = frame.get("scope", {})
|
||||
scope_payload: dict = frame.get("scope", {})
|
||||
logger.info(
|
||||
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
json.dumps(scope, ensure_ascii=True)[:200],
|
||||
message[:200],
|
||||
)
|
||||
|
||||
# ── Memory: enrich context before LLM call ────────────────────────
|
||||
scope = ContextualScope.model_validate(scope_payload)
|
||||
|
||||
# Enrich context with memory before the LLM call.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
@@ -287,7 +341,8 @@ async def _handle_floating_request(
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"scope": scope,
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
**memory_context,
|
||||
}
|
||||
@@ -296,7 +351,12 @@ async def _handle_floating_request(
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_floating_stream(user_id, message, context)
|
||||
event_stream = run_contextual_stream(
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
context=context,
|
||||
scope=scope,
|
||||
)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
@@ -304,20 +364,20 @@ async def _handle_floating_request(
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: floating_request failed user=%s req=%s: %s",
|
||||
"device_ws: contextual_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# ── Memory: store episode after response ──────────────────────────
|
||||
# Store episode so the contextual agent can recall prior turns.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
@@ -325,6 +385,206 @@ async def _handle_floating_request(
|
||||
)
|
||||
|
||||
|
||||
async def _handle_contextual_scope_update(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_scope_update frame.
|
||||
|
||||
Injects a synthetic system message into the session buffer so the next
|
||||
agent turn knows the user navigated. No LLM call is made.
|
||||
"""
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope = ContextualScope.model_validate(frame.get("scope", {}))
|
||||
block = render_scope_block(scope)
|
||||
buf = get_session_buffer(user_id, session_id, channel="contextual")
|
||||
buf.append_system_message(
|
||||
f"User navigated to a new view. {block} Treat this as the new active context."
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.contextual_scope_ack,
|
||||
"session_id": session_id,
|
||||
}))
|
||||
logger.info(
|
||||
"device_ws: contextual_scope_update user=%s session=%s page=%s",
|
||||
user_id, session_id, scope.page,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a brief_request frame — streams plain-text brief back on the socket.
|
||||
|
||||
No episode storage — briefs are not conversations.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
mode: str = frame.get("mode", "home")
|
||||
project_id: str | None = frame.get("project_id")
|
||||
|
||||
logger.info(
|
||||
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
|
||||
user_id, request_id, mode, project_id,
|
||||
)
|
||||
|
||||
# Validate project_id for project mode before touching LLM.
|
||||
if mode == "project":
|
||||
try:
|
||||
if not project_id:
|
||||
raise ValueError("project_id required for project mode")
|
||||
_uuid.UUID(project_id)
|
||||
except (ValueError, AttributeError) as exc:
|
||||
logger.warning(
|
||||
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
# Enrich context with memory (no user message — use empty string as probe).
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
"",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
if mode == "project":
|
||||
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
|
||||
else:
|
||||
event_stream = run_home_brief(user_id, context)
|
||||
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: brief_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
logger.info(
|
||||
"device_ws: brief_request_end user=%s req=%s mode=%s",
|
||||
user_id, request_id, mode,
|
||||
)
|
||||
|
||||
|
||||
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_task_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
||||
|
||||
Streams the briefing markdown back to the client.
|
||||
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
||||
"""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
||||
user_id, request_id, task_id, project_id,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
f"task brief: {task_id}",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
|
||||
try:
|
||||
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
elif ws_frame.type == "stream_start":
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# stream_end is emitted below with mutations — skip formatter's version
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
||||
user_id, request_id, task_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Extract canvas block then emit stream_end with optional mutations.
|
||||
full_response = "".join(response_chunks)
|
||||
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
||||
|
||||
mutations: list[dict] = []
|
||||
if canvas_content:
|
||||
mutations.append({
|
||||
"type": "canvas_draft",
|
||||
"content": canvas_content,
|
||||
"kind": canvas_kind,
|
||||
})
|
||||
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
||||
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
||||
)
|
||||
|
||||
|
||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -382,6 +642,174 @@ async def _handle_journey_message(
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_index_session_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
||||
|
||||
if not session_id:
|
||||
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||
return
|
||||
|
||||
_index_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"processed": 0,
|
||||
"total": total,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(
|
||||
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||
user_id, session_id, project_id, total,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_index_session_cancel(
|
||||
websocket: WebSocket,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
session = _index_sessions.get(session_id)
|
||||
if session:
|
||||
session["cancelled"] = True
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "cancelled",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||
|
||||
|
||||
async def _handle_index_file_batch(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Process a batch of files for an index session, streaming results back."""
|
||||
# Lazy imports to avoid heavy load at module startup.
|
||||
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||
summarize_image,
|
||||
summarize_pdf,
|
||||
summarize_docx,
|
||||
summarize_text,
|
||||
)
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
files: list[dict] = frame.get("files", [])
|
||||
|
||||
session = _index_sessions.get(session_id)
|
||||
if not session or session.get("cancelled"):
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||
|
||||
for file_info in files:
|
||||
if session.get("cancelled"):
|
||||
return
|
||||
|
||||
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
||||
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
||||
kind: str = file_info.get("kind") or "text"
|
||||
content: str = file_info.get("content") or ""
|
||||
ext: str = file_info.get("ext") or ""
|
||||
mime: str = file_info.get("mime") or "application/octet-stream"
|
||||
name: str = rel_path.split("/")[-1] or rel_path
|
||||
|
||||
try:
|
||||
if kind == "image":
|
||||
res = await summarize_image(image_b64=content, mime=mime)
|
||||
elif kind == "pdf":
|
||||
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||
elif kind == "docx":
|
||||
res = await summarize_docx(docx_b64=content, name=name)
|
||||
else:
|
||||
res = await summarize_text(content=content, ext=ext, name=name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||
session_id, rel_path, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": None,
|
||||
"tokensUsed": 0,
|
||||
"error": str(exc),
|
||||
}))
|
||||
session["processed"] += 1
|
||||
continue
|
||||
|
||||
# Account for token usage and check cap.
|
||||
usage = await add_token_usage(
|
||||
user_id=user_id,
|
||||
feature="folder_index",
|
||||
tokens=res.tokens_used,
|
||||
db=db,
|
||||
cap=cap,
|
||||
)
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": res.summary,
|
||||
"tokensUsed": res.tokens_used,
|
||||
}))
|
||||
session["processed"] += 1
|
||||
|
||||
if usage.exhausted:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "quota_exceeded",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||
user_id, session_id,
|
||||
)
|
||||
return
|
||||
|
||||
# After processing the batch, emit progress.
|
||||
processed = session["processed"]
|
||||
total = session["total"]
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_progress,
|
||||
"sessionId": session_id,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
}))
|
||||
|
||||
if processed >= total:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "completed",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||
user_id, session_id, processed,
|
||||
)
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
|
||||
225
app/api/routes/memory.py
Normal file
225
app/api/routes/memory.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Memory management routes — view/edit/delete user memory tiers.
|
||||
|
||||
All routes require authentication. Data is always user-scoped.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
)
|
||||
from app.schemas import UserProfile
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ALLOWED_PREDICATES = {
|
||||
"works_at",
|
||||
"reports_to",
|
||||
"stakeholder_of",
|
||||
"last_contacted_on",
|
||||
"owes_followup",
|
||||
"manages",
|
||||
"collaborates_with",
|
||||
"owns",
|
||||
"member_of",
|
||||
"custom",
|
||||
}
|
||||
|
||||
|
||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class RelationOut(BaseModel):
|
||||
id: str
|
||||
subject_label: str
|
||||
subject_type: str
|
||||
predicate: str
|
||||
object_label: str
|
||||
object_type: str
|
||||
confidence: float
|
||||
last_confirmed_at: int | None = None # epoch ms
|
||||
|
||||
|
||||
class RelationPatch(BaseModel):
|
||||
subject_label: str | None = None
|
||||
object_label: str | None = None
|
||||
predicate: str | None = None
|
||||
confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class CoreAddBody(BaseModel):
|
||||
key: str = Field(..., min_length=1, max_length=255)
|
||||
value: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _relation_to_out(row: MemoryRelation) -> RelationOut:
|
||||
last_ms: int | None = None
|
||||
if row.last_confirmed_at is not None:
|
||||
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
|
||||
return RelationOut(
|
||||
id=row.id,
|
||||
subject_label=row.subject_label,
|
||||
subject_type=row.subject_type,
|
||||
predicate=row.predicate,
|
||||
object_label=row.object_label,
|
||||
object_type=row.object_type,
|
||||
confidence=row.confidence,
|
||||
last_confirmed_at=last_ms,
|
||||
)
|
||||
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/core", response_model=dict[str, str])
|
||||
async def get_core_memory(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Return all core memory k/v pairs (plaintext) for the current user."""
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(current_user.id)
|
||||
return {b["label"]: b["value"] for b in blocks}
|
||||
|
||||
|
||||
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_core_key(
|
||||
key: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Delete a single core memory key (GDPR Art. 17)."""
|
||||
mw = MemoryMiddleware(db)
|
||||
deleted = await mw.delete_core(current_user.id, key)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||
|
||||
|
||||
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
|
||||
async def add_core_key(
|
||||
body: CoreAddBody,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Add or overwrite a core memory key/value pair."""
|
||||
mw = MemoryMiddleware(db)
|
||||
await mw.update_core(current_user.id, body.key, body.value)
|
||||
return {body.key: body.value}
|
||||
|
||||
|
||||
@router.get("/relational", response_model=list[RelationOut])
|
||||
async def get_relational_memory(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[RelationOut]:
|
||||
"""Return all relational memory rows for the current user."""
|
||||
mw = MemoryMiddleware(db)
|
||||
rows = await mw.query_relations(current_user.id, limit=200)
|
||||
return [_relation_to_out(r) for r in rows]
|
||||
|
||||
|
||||
@router.patch("/relational/{relation_id}", response_model=RelationOut)
|
||||
async def patch_relation(
|
||||
relation_id: str,
|
||||
body: RelationPatch,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> RelationOut:
|
||||
"""Edit a relation row's labels, predicate, or confidence."""
|
||||
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.id == relation_id,
|
||||
MemoryRelation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||
|
||||
if body.subject_label is not None:
|
||||
row.subject_label = body.subject_label
|
||||
if body.object_label is not None:
|
||||
row.object_label = body.object_label
|
||||
if body.predicate is not None:
|
||||
row.predicate = body.predicate
|
||||
if body.confidence is not None:
|
||||
row.confidence = body.confidence
|
||||
row.last_confirmed_at = datetime.now(timezone.utc)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
return _relation_to_out(row)
|
||||
|
||||
|
||||
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_relation(
|
||||
relation_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Hard-delete a relation row (GDPR Art. 17)."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.id == relation_id,
|
||||
MemoryRelation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||
await db.delete(row)
|
||||
await db.commit()
|
||||
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
|
||||
|
||||
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def forget_all(
|
||||
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Wipe all memory tiers for the current user (GDPR Art. 17).
|
||||
|
||||
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
|
||||
"""
|
||||
if x_confirm != "true":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
|
||||
)
|
||||
|
||||
uid = current_user.id
|
||||
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
|
||||
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
|
||||
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
|
||||
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
|
||||
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
|
||||
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
|
||||
await db.commit()
|
||||
logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||
139
app/billing/quota.py
Normal file
139
app/billing/quota.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Quota checks and atomic token-usage accounting for folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.billing.tier_manager import TierManager
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import BillingTier
|
||||
|
||||
|
||||
class QuotaExceeded(Exception):
|
||||
"""Raised when a folder operation cannot proceed under the user's tier."""
|
||||
|
||||
def __init__(self, reason: str, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.reason = reason # "max_files" | "monthly_tokens"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsageResult:
|
||||
tokens_used: int
|
||||
exhausted: bool
|
||||
|
||||
|
||||
def _current_year_month() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
|
||||
_tier_manager = TierManager()
|
||||
|
||||
|
||||
async def check_folder_quota(
|
||||
*,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
estimated_files: int,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
||||
would be violated. -1 in either feature means unlimited."""
|
||||
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
||||
if max_files != -1 and estimated_files > max_files:
|
||||
raise QuotaExceeded(
|
||||
"max_files",
|
||||
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
||||
)
|
||||
|
||||
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
if cap == -1:
|
||||
return
|
||||
ym = _current_year_month()
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == "folder_index",
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
used = row.tokens_used if row else 0
|
||||
if used >= cap:
|
||||
raise QuotaExceeded(
|
||||
"monthly_tokens",
|
||||
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
||||
)
|
||||
|
||||
|
||||
async def add_token_usage(
|
||||
*,
|
||||
user_id: str,
|
||||
feature: str,
|
||||
tokens: int,
|
||||
db: AsyncSession,
|
||||
cap: int | None = None,
|
||||
) -> TokenUsageResult:
|
||||
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
||||
|
||||
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
||||
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
||||
Returns post-update total and whether cap is exhausted.
|
||||
"""
|
||||
ym = _current_year_month()
|
||||
|
||||
# Detect dialect to choose between native upsert and portable fallback.
|
||||
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
# Native atomic upsert — production path.
|
||||
stmt = (
|
||||
pg_insert(MonthlyTokenUsage)
|
||||
.values(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["user_id", "year_month", "feature"],
|
||||
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
||||
)
|
||||
.returning(MonthlyTokenUsage.tokens_used)
|
||||
)
|
||||
used: int = (await db.execute(stmt)).scalar_one()
|
||||
await db.commit()
|
||||
else:
|
||||
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == feature,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
row = MonthlyTokenUsage(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
db.add(row)
|
||||
else:
|
||||
row.tokens_used += tokens
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
used = row.tokens_used
|
||||
|
||||
exhausted = cap is not None and cap != -1 and used >= cap
|
||||
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
||||
@@ -25,6 +25,12 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 200,
|
||||
"folder_monthly_tokens": 100_000,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
@@ -33,6 +39,12 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"relational_memory": True, # person/project predicates
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 5000,
|
||||
"folder_monthly_tokens": 2_000_000,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
@@ -41,6 +53,12 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
@@ -49,6 +67,12 @@ FEATURES: dict[str, dict[str, Any]] = {
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
}
|
||||
|
||||
@@ -107,6 +131,13 @@ class TierManager:
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
||||
"""Return integer feature value for tier. -1 means unlimited."""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if not isinstance(value, int):
|
||||
return 0
|
||||
return value
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
|
||||
@@ -16,17 +16,23 @@ class Settings(BaseSettings):
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
GOOGLE_API_KEY: str = ""
|
||||
CEREBRAS_API_KEY: str = ""
|
||||
GROQ_API_KEY: str = ""
|
||||
DEEPSEEK_API_KEY: str = ""
|
||||
|
||||
LLM_MODEL: str = "gpt-4o"
|
||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||
|
||||
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
||||
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
|
||||
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
@@ -69,9 +75,11 @@ class Settings(BaseSettings):
|
||||
LANGFUSE_PUBLIC_KEY: str = ""
|
||||
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||
|
||||
SCHEDULER_ENABLED: bool = True
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -287,7 +287,6 @@ async def _run_agent_with_tools(
|
||||
return final_text
|
||||
|
||||
for call in response.tool_calls:
|
||||
call_id = str(call.get("id", ""))
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
@@ -659,9 +658,14 @@ async def run_local_agent(
|
||||
# ── Phase B: single LLM call ─────────────────────────
|
||||
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
||||
no_match_behavior = _get_no_match_behavior(agent_config)
|
||||
global_rules_lines = "\n".join(
|
||||
f"- {r}" for r in agent_config.get("global_rules", [])
|
||||
)
|
||||
base_global_rules = list(agent_config.get("global_rules", []))
|
||||
if "notes" in config.data_types:
|
||||
base_global_rules.append(
|
||||
"For notes: when updating an existing note use `propose_note_edit` "
|
||||
"(type=append/insert/replace) so the user can review AI changes. "
|
||||
"Only call `update_note` for complete content replacement without review."
|
||||
)
|
||||
global_rules_lines = "\n".join(f"- {r}" for r in base_global_rules)
|
||||
metadata_section = _format_metadata(preprocessed.metadata)
|
||||
|
||||
system_prompt = compile_prompt(
|
||||
|
||||
96
app/core/agent_session_buffer.py
Normal file
96
app/core/agent_session_buffer.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""In-process TTL buffer for per-session LangChain message history.
|
||||
|
||||
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
|
||||
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
|
||||
conversation turns without it being lossy through the wire.
|
||||
|
||||
Single-process only. For multi-worker deployments, replace the _SessionBuffer
|
||||
implementation with one backed by Redis (serialize LangChain messages to dicts via
|
||||
message_to_dict / messages_from_dict from langchain_core.messages).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
|
||||
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
|
||||
|
||||
|
||||
class _SessionBuffer:
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _evict_stale(self) -> None:
|
||||
now = time.monotonic()
|
||||
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
|
||||
for k in stale:
|
||||
del self._store[k]
|
||||
|
||||
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
|
||||
key = (user_id, session_id)
|
||||
with self._lock:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
ts, msgs = entry
|
||||
if time.monotonic() - ts > SESSION_TTL_SECONDS:
|
||||
del self._store[key]
|
||||
return None
|
||||
self._store[key] = (time.monotonic(), msgs)
|
||||
return list(msgs)
|
||||
|
||||
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
|
||||
key = (user_id, session_id)
|
||||
capped = messages[-MAX_MESSAGES_PER_SESSION:]
|
||||
with self._lock:
|
||||
self._evict_stale()
|
||||
self._store[key] = (time.monotonic(), capped)
|
||||
|
||||
def clear(self, user_id: str, session_id: str) -> None:
|
||||
with self._lock:
|
||||
self._store.pop((user_id, session_id), None)
|
||||
|
||||
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
|
||||
"""Append a synthetic system message to the buffer for the given session.
|
||||
|
||||
Creates the session slot if it does not yet exist. Used by the
|
||||
contextual_scope_update handler to inject navigation events without
|
||||
making an LLM call.
|
||||
"""
|
||||
from langchain_core.messages import SystemMessage # noqa: PLC0415
|
||||
|
||||
key = (user_id, session_id)
|
||||
with self._lock:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
msgs: list[BaseMessage] = [SystemMessage(content=text)]
|
||||
else:
|
||||
_, existing = entry
|
||||
msgs = list(existing) + [SystemMessage(content=text)]
|
||||
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
|
||||
self._store[key] = (time.monotonic(), capped)
|
||||
|
||||
|
||||
class ContextualBufferProxy:
|
||||
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
|
||||
|
||||
Returned by get_session_buffer() so callers can call
|
||||
``proxy.append_system_message(text)`` without threading user_id/session_id
|
||||
through every call site.
|
||||
"""
|
||||
|
||||
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
|
||||
self._buf = buf
|
||||
self._user_id = user_id
|
||||
self._session_id = session_id
|
||||
|
||||
def append_system_message(self, text: str) -> None:
|
||||
self._buf.append_system_message(self._user_id, self._session_id, text)
|
||||
|
||||
|
||||
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
||||
session_buffer = _SessionBuffer()
|
||||
228
app/core/brief_agent.py
Normal file
228
app/core/brief_agent.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Brief agent — produces plain-text home and project status briefs.
|
||||
|
||||
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
|
||||
the brief prompt forbids XML tags, so skipping post-processing is intentional.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||
from app.agents.task_agent import TASK_READ_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||
from app.core.deep_agent import (
|
||||
_language_instruction,
|
||||
_proactive_hints_injection,
|
||||
_read_only_memory_tools,
|
||||
_relational_memory_injection,
|
||||
_run_single_agent_stream,
|
||||
_trace_id_from_context,
|
||||
build_brief_multi_project_manifest,
|
||||
)
|
||||
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||
|
||||
_LANGUAGE_NAMES: dict[str, str] = {
|
||||
"en": "English", "it": "Italian", "es": "Spanish",
|
||||
"fr": "French", "de": "German",
|
||||
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||
"spanish": "Spanish", "español": "Spanish",
|
||||
"french": "French", "français": "French",
|
||||
"german": "German", "deutsch": "German",
|
||||
}
|
||||
|
||||
_HOME_BRIEF_FALLBACK = """\
|
||||
You are the user's personal assistant producing a short daily brief.
|
||||
|
||||
ROLE
|
||||
Act like a calm, attentive secretary writing a stand-up note for your boss.
|
||||
Warm and human, never breezy. Never cheerful filler, never emojis, never
|
||||
"here is your brief" meta-text. The user is opening the app mid-workday and
|
||||
is probably stressed — your job is to lower cognitive load, not add noise.
|
||||
|
||||
TOOLS — always call before writing
|
||||
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
|
||||
- list_tasks_due_today — tasks the user owes today
|
||||
- list_timelines_today — events starting or ending today
|
||||
- list_all_projects — projects currently in progress or at risk
|
||||
- memory_list_blocks / memory_get — personal context about people, clients,
|
||||
payment habits, working preferences
|
||||
If a tool returns nothing, simply omit that topic. Never report zeros.
|
||||
|
||||
WHAT TO INCLUDE
|
||||
1. Tasks due today (title + priority; group the 1-2 most important).
|
||||
2. Timeline events starting or ending today (and anything that starts/ends
|
||||
tomorrow if the user has a very light day).
|
||||
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
|
||||
4. Memory-aware colour where it sharpens the brief. Examples:
|
||||
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
|
||||
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
|
||||
Only add a memory line when it changes what the user does. Do not pad.
|
||||
|
||||
WHAT TO OMIT
|
||||
- Zero-counts ("no overdue items", "0 meetings today").
|
||||
- Statistics ("2 active projects, 3 completed tasks").
|
||||
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
|
||||
- Meta-phrases ("here is", "let me know if", "hope this helps").
|
||||
- XML/HTML tags of any kind. Plain prose only.
|
||||
|
||||
LIGHT-DAY CLAUSE
|
||||
If tasks + events + active-project-nudges together produce fewer than two
|
||||
sentences of content, also list 1-2 projects in status on_hold or waiting
|
||||
and ask a single, specific question about them — e.g. "Is the Bianchi
|
||||
redesign still paused, or ready to pick back up?" One question max, grounded
|
||||
in a real project name.
|
||||
|
||||
VOICE
|
||||
- Calm. Concise. Human. Short sentences.
|
||||
- Use **bold** sparingly for task titles, project names, and people's names.
|
||||
- No bullet lists. Flow as 2-4 sentences of prose.
|
||||
|
||||
LENGTH
|
||||
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
|
||||
|
||||
Respond in the user's language ({language}). Today is {today}.\
|
||||
"""
|
||||
|
||||
_PROJECT_BRIEF_FALLBACK = """\
|
||||
You are the project assistant producing a short status brief for ONE project.
|
||||
|
||||
ROLE
|
||||
A senior project manager summarising state-of-play for the owner. Factual,
|
||||
sharp, forward-looking. Never reassuring filler, never emojis.
|
||||
|
||||
SCOPE
|
||||
Work only with project_id = {project_id}. Do not mention or pull data from
|
||||
other projects. Use tools to fetch fresh data:
|
||||
- get_project — current status, dates, description
|
||||
- list_tasks(project_id) — open work, split by status
|
||||
- list_timelines(project_id) — milestones hit, upcoming, overdue
|
||||
- list_notes(project_id) — any recent decisions or blockers
|
||||
- memory_get — relevant context about the client, collaborators, constraints
|
||||
|
||||
STRUCTURE — follow exactly, one short paragraph per section, no headers
|
||||
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
|
||||
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
|
||||
blocker note).
|
||||
2. **What's moving.** What was completed or progressed recently. Name specific
|
||||
tasks or milestones.
|
||||
3. **Next steps.** The 1-3 most important things the user should do next, in
|
||||
priority order. Be concrete — task name, who owns it, when due if known.
|
||||
If waiting on someone else, name them and what the ask is.
|
||||
4. **Risks / memory-flagged items.** One line max. Only include when there is
|
||||
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
|
||||
scope change). Omit the section entirely if nothing to say.
|
||||
|
||||
WHAT TO OMIT
|
||||
- Zero-counts ("no overdue tasks").
|
||||
- Generic advice ("keep up the good work").
|
||||
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
|
||||
- XML/HTML tags or bracketed id lists. Plain prose only.
|
||||
|
||||
VOICE
|
||||
- Direct. Factual. No fluff.
|
||||
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
|
||||
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
|
||||
not "There is a blocker which is the client review").
|
||||
|
||||
LENGTH
|
||||
4-8 sentences total across the 3-4 sections. Hard cap 8.
|
||||
|
||||
Respond in the user's language ({language}). Today is {today}.\
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_language(context: dict[str, Any]) -> str:
|
||||
core = context.get("core_memory") or {}
|
||||
raw = (core.get("language") or "en").strip().lower()
|
||||
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
|
||||
|
||||
|
||||
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
return [
|
||||
*TASK_READ_TOOLS,
|
||||
*PROJECT_READ_TOOLS,
|
||||
*TIMELINE_READ_TOOLS,
|
||||
*NOTE_READ_TOOLS,
|
||||
*_read_only_memory_tools(user_id, trace_id),
|
||||
]
|
||||
|
||||
|
||||
async def run_home_brief(
|
||||
user_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream a plain-text daily home brief.
|
||||
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
from app.agents.folder_agent import FOLDER_TOOLS
|
||||
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
|
||||
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
|
||||
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
brief_manifest = await build_brief_multi_project_manifest()
|
||||
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
||||
|
||||
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message="Generate the daily brief.",
|
||||
context=context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="brief-agent",
|
||||
tools=tools,
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
async def run_project_brief(
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream a plain-text project status brief for project_id.
|
||||
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
|
||||
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
|
||||
system_prompt = compile_prompt(
|
||||
raw_template, langfuse_prompt,
|
||||
language=language, today=today, project_id=project_id,
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
tools = _build_read_tools(user_id, trace_id)
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message=f"Generate the project status brief for project {project_id}.",
|
||||
context=context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="brief-agent",
|
||||
tools=tools,
|
||||
):
|
||||
yield event
|
||||
File diff suppressed because it is too large
Load Diff
34
app/core/embeddings.py
Normal file
34
app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""OpenAI embedding helper for associative memory tier.
|
||||
|
||||
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||
Returns None on any failure — callers must implement a keyword fallback.
|
||||
Never raises; all exceptions are logged as warnings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_INPUT_CHARS = 8000
|
||||
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
|
||||
async def embed_text(text: str) -> list[float] | None:
|
||||
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||
try:
|
||||
client = AsyncOpenAI()
|
||||
truncated = text[:_MAX_INPUT_CHARS]
|
||||
response = await client.embeddings.create(
|
||||
input=truncated,
|
||||
model=_EMBEDDING_MODEL,
|
||||
)
|
||||
result: list[float] = response.data[0].embedding
|
||||
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||
return None
|
||||
183
app/core/folder_indexer.py
Normal file
183
app/core/folder_indexer.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Per-file summarisation for project folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pypdf import PdfReader
|
||||
from docx import Document as DocxDocument
|
||||
|
||||
from app.core.langfuse_client import (
|
||||
compile_prompt,
|
||||
extract_usage,
|
||||
get_langfuse,
|
||||
get_prompt_or_fallback,
|
||||
)
|
||||
from app.core.llm import get_llm
|
||||
|
||||
_TEXT_FALLBACK = (
|
||||
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
||||
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
||||
)
|
||||
_IMAGE_FALLBACK = (
|
||||
"You are summarising an image attached to a project folder.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
||||
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
||||
)
|
||||
_MAX_INPUT_CHARS = 6000
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
summary: str
|
||||
tokens_used: int
|
||||
|
||||
|
||||
async def _llm_text(messages: list) -> object:
|
||||
"""Make the LLM call for text summarisation.
|
||||
|
||||
Defined as a standalone async function so tests can patch it cleanly
|
||||
without needing to mock the LLM object itself.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def _llm_vision(messages: list) -> object:
|
||||
"""Make the LLM call for vision (image) summarisation.
|
||||
|
||||
Accepts the message list and returns the response directly, mirroring
|
||||
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
||||
"""Return a compact summary of an image file using vision.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_b64:
|
||||
Base64-encoded image bytes.
|
||||
mime:
|
||||
MIME type of the image, e.g. ``"image/png"``.
|
||||
file_name:
|
||||
Optional file name, attached to the Langfuse trace as input metadata.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
||||
messages = [
|
||||
SystemMessage(content=template),
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Summarise this image."},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
||||
]),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-image",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": file_name, "mime": mime},
|
||||
) as gen:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a text file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
content:
|
||||
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
||||
ext:
|
||||
File extension including the leading dot, e.g. ``".md"``.
|
||||
name:
|
||||
File name, e.g. ``"kickoff.md"``.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
||||
truncated = content[:_MAX_INPUT_CHARS]
|
||||
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
||||
messages = [
|
||||
SystemMessage(content=compiled),
|
||||
HumanMessage(content="Summarise this file."),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-text",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
||||
) as gen:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
def _extract_pdf_text(pdf_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
||||
reader = PdfReader(buf)
|
||||
parts: list[str] = []
|
||||
for page in reader.pages:
|
||||
try:
|
||||
parts.append(page.extract_text() or "")
|
||||
except Exception:
|
||||
continue
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
def _extract_docx_text(docx_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(docx_b64))
|
||||
doc = DocxDocument(buf)
|
||||
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
||||
|
||||
|
||||
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a PDF file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pdf_b64:
|
||||
Base64-encoded PDF bytes.
|
||||
name:
|
||||
File name, e.g. ``"report.pdf"``.
|
||||
"""
|
||||
text = _extract_pdf_text(pdf_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".pdf", name=name)
|
||||
|
||||
|
||||
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a DOCX file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
docx_b64:
|
||||
Base64-encoded DOCX bytes.
|
||||
name:
|
||||
File name, e.g. ``"spec.docx"``.
|
||||
"""
|
||||
text = _extract_docx_text(docx_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".docx", name=name)
|
||||
@@ -51,6 +51,10 @@ def _api_key_for_model(model: str) -> str | None:
|
||||
return settings.GOOGLE_API_KEY or None
|
||||
if model.startswith("cerebras/"):
|
||||
return settings.CEREBRAS_API_KEY or None
|
||||
if model.startswith("groq/"):
|
||||
return settings.GROQ_API_KEY or None
|
||||
if model.startswith("deepseek/"):
|
||||
return settings.DEEPSEEK_API_KEY or None
|
||||
if model.startswith("github_copilot/"):
|
||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||
# No API key is required; returning None lets LiteLLM handle auth.
|
||||
@@ -99,10 +103,15 @@ def get_llm(
|
||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||
"note-summarizer": lambda: "gpt-4o-mini",
|
||||
}
|
||||
|
||||
|
||||
|
||||
450
app/core/memory_extraction.py
Normal file
450
app/core/memory_extraction.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||
|
||||
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||
routines, and relations from the latest conversation turn.
|
||||
|
||||
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||
|
||||
Design notes
|
||||
------------
|
||||
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
||||
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||
treated as identifiers (safe to log per spec).
|
||||
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||
|
||||
_EXTRACTION_FALLBACK = (
|
||||
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
||||
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
||||
"preferences, routines, and person/project relations worth remembering.\n\n"
|
||||
"Output JSON matching this schema exactly:\n"
|
||||
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
||||
'"content": "<short canonical statement>", '
|
||||
'"target_tier": "<core|associative|relational|proactive>", '
|
||||
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
||||
"Rules:\n"
|
||||
"- Skip small talk, greetings, one-off questions.\n"
|
||||
"- Max 5 candidates per call.\n"
|
||||
"- Only extract durable information (still true next week).\n"
|
||||
"- For type=relation: subject/predicate/object required.\n"
|
||||
"- Default confidence=0.7.\n\n"
|
||||
"## Last turn\n{last_turn}\n\n"
|
||||
"## Core memory (current)\n{core_memory}\n\n"
|
||||
"## Recent episodes\n{recent_episodes}"
|
||||
)
|
||||
|
||||
_DECIDE_FALLBACK = (
|
||||
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
||||
"existing memories from the same tier, decide what action to take.\n\n"
|
||||
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
||||
"- ADD: new information not in existing memories.\n"
|
||||
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
||||
"- DELETE: states something is no longer true.\n"
|
||||
"- NOOP: already captured accurately.\n\n"
|
||||
"## New candidate\n{candidate}\n\n"
|
||||
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||
|
||||
class MemoryCandidate(BaseModel):
|
||||
type: Literal["fact", "preference", "relation", "routine"]
|
||||
content: str
|
||||
target_tier: Literal["core", "associative", "relational", "proactive"]
|
||||
subject: str | None = None
|
||||
predicate: str | None = None
|
||||
object: str | None = None
|
||||
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||
|
||||
async def extract_candidates(
|
||||
last_turn: str,
|
||||
core_memory: dict[str, str],
|
||||
recent_episodes: list[str],
|
||||
) -> ExtractionResult:
|
||||
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
||||
|
||||
Returns an ExtractionResult (may be empty on failure — never raises).
|
||||
"""
|
||||
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
||||
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
||||
|
||||
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: compile failed: %s", exc)
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
else:
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
# Bind JSON mode so the model always returns parseable output.
|
||||
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||
|
||||
lf = get_langfuse()
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Extract memory candidates as JSON."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-extraction",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
|
||||
raw = json.loads(response.content)
|
||||
result = ExtractionResult.model_validate(raw)
|
||||
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
||||
return ExtractionResult(candidates=[])
|
||||
|
||||
|
||||
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||
|
||||
async def decide_action(
|
||||
candidate: MemoryCandidate,
|
||||
existing: list[str],
|
||||
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
||||
"""Decide what to do with a candidate given existing memories in the same tier.
|
||||
|
||||
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
||||
Never raises.
|
||||
"""
|
||||
if not existing:
|
||||
return "ADD"
|
||||
|
||||
candidate_str = f"[{candidate.type}] {candidate.content}"
|
||||
existing_str = "\n".join(f"- {m}" for m in existing)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
||||
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
candidate=candidate_str,
|
||||
existing_memories=existing_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
else:
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Decide action."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-decide-action",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
verb = response.content.strip().upper()
|
||||
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
||||
return verb # type: ignore[return-value]
|
||||
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
||||
return "ADD"
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
||||
return "ADD"
|
||||
|
||||
|
||||
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||
|
||||
async def run_extraction(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
||||
|
||||
Steps:
|
||||
1. Load core memory + last 5 episodes.
|
||||
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
||||
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
||||
4. Trace via Langfuse.
|
||||
|
||||
Never raises — wraps everything in try/except.
|
||||
"""
|
||||
try:
|
||||
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _run_extraction_inner(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
middleware = MemoryMiddleware(db)
|
||||
fernet = await middleware._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
||||
return
|
||||
|
||||
# 1. Load context
|
||||
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
||||
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
||||
|
||||
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
||||
|
||||
lf = get_langfuse()
|
||||
|
||||
async def _run(trace_id: str | None) -> dict[str, Any]:
|
||||
# 2. Extract candidates
|
||||
result = await extract_candidates(last_turn, core, episodes)
|
||||
if not result.candidates:
|
||||
logger.info("memory_extraction: no candidates user=%s", user_id)
|
||||
return {"candidates": 0, "applied": 0}
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: processing %d candidates user=%s trace=%s",
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
trace_id or "-",
|
||||
)
|
||||
|
||||
# 3. Apply each candidate
|
||||
applied = 0
|
||||
actions: list[str] = []
|
||||
for candidate in result.candidates:
|
||||
try:
|
||||
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
||||
applied += 1
|
||||
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
||||
candidate.content[:80],
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: applied %d/%d candidates user=%s",
|
||||
applied,
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
)
|
||||
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
||||
|
||||
with langfuse_context(user_id=user_id, session_id=session_id):
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="memory-extraction-pipeline",
|
||||
input={"last_turn_preview": last_turn[:200]},
|
||||
) as span:
|
||||
summary = await _run(trace_id=span.id)
|
||||
span.update(output=summary)
|
||||
try:
|
||||
lf.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
await _run(trace_id=None)
|
||||
|
||||
|
||||
async def _apply_candidate(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Any,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
||||
|
||||
neighbours: list[str] = []
|
||||
|
||||
if candidate.target_tier == "core":
|
||||
# For core tier: neighbours are existing core block values for similar keys.
|
||||
blocks = await middleware.list_core_blocks(user_id)
|
||||
neighbours = [b["value"] for b in blocks[:3]]
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
||||
|
||||
elif candidate.target_tier == "relational":
|
||||
# Relation candidates handled specially — passed to upsert_relation directly.
|
||||
# Neighbours: search by subject label if available.
|
||||
neighbours = []
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
||||
|
||||
action = await decide_action(candidate, neighbours)
|
||||
logger.info(
|
||||
"memory_extraction: candidate type=%s tier=%s action=%s",
|
||||
candidate.type,
|
||||
candidate.target_tier,
|
||||
action,
|
||||
)
|
||||
|
||||
if action == "NOOP":
|
||||
return
|
||||
|
||||
if candidate.target_tier == "relational":
|
||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||
if candidate.subject and candidate.predicate and candidate.object:
|
||||
await _upsert_relation(
|
||||
middleware, db, user_id, candidate, trace_id
|
||||
)
|
||||
return
|
||||
|
||||
if action in ("ADD", "UPDATE"):
|
||||
if candidate.target_tier == "core":
|
||||
# Derive a short key from the content (first 40 chars, snake_cased).
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
await middleware.store_associative(user_id, candidate.content)
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
||||
|
||||
elif action == "DELETE":
|
||||
if candidate.target_tier == "core":
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.delete_core(user_id, key)
|
||||
|
||||
|
||||
def _content_to_key(content: str) -> str:
|
||||
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||
import re # noqa: PLC0415
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||
return slug or "memory"
|
||||
|
||||
|
||||
async def _upsert_relation(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject or "unknown",
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate or "related_to",
|
||||
object_=candidate.object or "unknown",
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
logger.info(
|
||||
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
|
||||
|
||||
async def _store_proactive_stub(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
fernet: Any,
|
||||
) -> None:
|
||||
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
||||
import uuid # noqa: PLC0415
|
||||
from app.models import MemoryProactive # noqa: PLC0415
|
||||
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
||||
|
||||
encrypted = _encrypt(fernet, candidate.content)
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=candidate.confidence,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
||||
await db.rollback()
|
||||
581
app/core/memory_maintenance.py
Normal file
581
app/core/memory_maintenance.py
Normal file
@@ -0,0 +1,581 @@
|
||||
"""Memory maintenance jobs — Phase 3/5.
|
||||
|
||||
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||
|
||||
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
||||
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||
|
||||
All are safe to call manually or from tests; they never raise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Decay parameters for relations
|
||||
_DECAY_FACTOR = 0.95
|
||||
_DECAY_PERIOD_DAYS = 30
|
||||
_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Proactive pattern decay: 10 % per 7 days since last sighting
|
||||
_PROACTIVE_DECAY_FACTOR = 0.9
|
||||
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
||||
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Mining: require at least this many episodes to attempt pattern extraction
|
||||
_MIN_EPISODES_FOR_MINING = 3
|
||||
_MINING_LOOKBACK_DAYS = 30
|
||||
|
||||
# Audit: caps to control token cost
|
||||
_AUDIT_MAX_FACTS = 50
|
||||
_AUDIT_MAX_LABELS = 100
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||
"""Apply confidence decay to all relation rows for a user.
|
||||
|
||||
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
|
||||
Rows whose confidence falls below 0.2 are deleted.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _decay_relations_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted = 0
|
||||
decayed = 0
|
||||
|
||||
for row in rows:
|
||||
reference = row.last_confirmed_at or row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
days_elapsed = (now - reference).days
|
||||
if days_elapsed < _DECAY_PERIOD_DAYS:
|
||||
continue
|
||||
|
||||
periods = days_elapsed // _DECAY_PERIOD_DAYS
|
||||
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
|
||||
|
||||
if new_confidence < _PRUNE_THRESHOLD:
|
||||
await db.delete(row)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
|
||||
"confidence=%.3f (below threshold)",
|
||||
row.id, user_id, row.subject_label, row.predicate, new_confidence,
|
||||
)
|
||||
else:
|
||||
row.confidence = new_confidence
|
||||
decayed += 1
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
|
||||
user_id, decayed, deleted,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def drain_extraction_queue(db: AsyncSession) -> None:
|
||||
"""Process pending ExtractionQueue rows for Free-tier users.
|
||||
|
||||
Each row corresponds to a stored episode that should be fed through the
|
||||
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _drain_extraction_queue_inner(db)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||
|
||||
|
||||
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
|
||||
from app.models import ExtractionQueue # noqa: PLC0415
|
||||
|
||||
result = await db.execute(select(ExtractionQueue))
|
||||
rows = result.scalars().all()
|
||||
|
||||
if not rows:
|
||||
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
|
||||
return
|
||||
|
||||
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
|
||||
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
|
||||
processed = 0
|
||||
for row in rows:
|
||||
try:
|
||||
await run_extraction(
|
||||
db=db,
|
||||
user_id=row.user_id,
|
||||
last_user_msg="",
|
||||
last_assistant_msg="",
|
||||
session_id=None,
|
||||
)
|
||||
await db.delete(row)
|
||||
await db.commit()
|
||||
processed += 1
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: drain failed row=%s user=%s: %s",
|
||||
row.id, row.user_id, exc,
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
||||
|
||||
|
||||
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
|
||||
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
|
||||
|
||||
Steps:
|
||||
1. Gate on proactive_mining tier feature.
|
||||
2. Load + decrypt last 30 days of episodic summaries.
|
||||
3. Call gpt-4o-mini to identify recurring patterns.
|
||||
4. Encrypt and store each pattern in memory_proactive.
|
||||
5. Apply decay to existing proactive rows.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _mine_proactive_patterns_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
if not tier_manager.check_feature(tier, "proactive_mining"):
|
||||
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
|
||||
return
|
||||
|
||||
# Load user Fernet key
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
|
||||
|
||||
episodes_result = await db.execute(
|
||||
select(MemoryEpisodic)
|
||||
.where(
|
||||
MemoryEpisodic.user_id == user_id,
|
||||
MemoryEpisodic.created_at >= cutoff,
|
||||
)
|
||||
.order_by(MemoryEpisodic.created_at.asc())
|
||||
)
|
||||
episode_rows = episodes_result.scalars().all()
|
||||
|
||||
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
|
||||
logger.info(
|
||||
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
|
||||
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
|
||||
)
|
||||
return
|
||||
|
||||
summaries: list[str] = []
|
||||
for ep in episode_rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
|
||||
summaries.append(plaintext)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
|
||||
patterns = await _extract_proactive_patterns(summaries)
|
||||
if not patterns:
|
||||
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
|
||||
return
|
||||
|
||||
stored = 0
|
||||
for pattern_text in patterns:
|
||||
try:
|
||||
encrypted = fernet.encrypt(pattern_text.encode()).decode()
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=0.7,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
stored += 1
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
|
||||
user_id, stored,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
return
|
||||
|
||||
await _decay_proactive_patterns(db, user_id, fernet)
|
||||
|
||||
|
||||
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
|
||||
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
|
||||
from app.core.llm import get_agent_llm # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-miner", temperature=0)
|
||||
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
|
||||
prompt = (
|
||||
"You are analyzing conversation history for a personal AI secretary. "
|
||||
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
|
||||
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
|
||||
"Return each pattern as a plain, short English sentence on its own line. "
|
||||
"No numbering, no bullet points, no extra text.\n\n"
|
||||
f"Conversation history:\n{combined}"
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(prompt)
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
|
||||
return lines[:5]
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
|
||||
"""Decay confidence of existing proactive patterns; prune below threshold."""
|
||||
result = await db.execute(
|
||||
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted = 0
|
||||
decayed = 0
|
||||
|
||||
for row in rows:
|
||||
reference = row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
days_elapsed = (now - reference).days
|
||||
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
|
||||
continue
|
||||
|
||||
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
|
||||
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
|
||||
|
||||
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
|
||||
await db.delete(row)
|
||||
deleted += 1
|
||||
else:
|
||||
row.confidence = new_confidence
|
||||
decayed += 1
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
|
||||
user_id, decayed, deleted,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||
|
||||
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||
"You are auditing a personal AI assistant's memory bank. "
|
||||
"Each fact has an ID in brackets. "
|
||||
"Find pairs that directly contradict each other "
|
||||
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||
'Return ONLY a valid JSON array, no markdown fences: '
|
||||
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||
"If no contradictions, return [].\n\n"
|
||||
"Facts:\n{facts}"
|
||||
)
|
||||
|
||||
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||
"These are names of people, companies, projects, or topics. "
|
||||
"Group labels that clearly refer to the same real-world entity "
|
||||
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||
"Return ONLY a valid JSON array, no markdown fences: "
|
||||
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||
"Labels:\n{labels}"
|
||||
)
|
||||
|
||||
|
||||
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||
|
||||
Steps:
|
||||
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||
4. Rewrite variant labels to their canonical form in-place.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _audit_memory_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
await _scan_associative_contradictions(db, user_id, fernet)
|
||||
await _canonicalize_relation_labels(db, user_id)
|
||||
|
||||
|
||||
async def _scan_associative_contradictions(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
) -> None:
|
||||
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||
result = await db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_AUDIT_MAX_FACTS)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if len(rows) < 2:
|
||||
return
|
||||
|
||||
id_to_text: dict[str, str] = {}
|
||||
for row in rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||
id_to_text[row.id] = plaintext
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if len(id_to_text) < 2:
|
||||
return
|
||||
|
||||
id_list = list(id_to_text.keys())
|
||||
numbered = "\n".join(
|
||||
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||
)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Audit facts for contradictions."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-contradictions",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
deletions = json.loads(text.strip())
|
||||
if not isinstance(deletions, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
deleted = 0
|
||||
for item in deletions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
rid = item.get("delete")
|
||||
if not rid or rid not in id_to_text:
|
||||
continue
|
||||
result2 = await db.execute(
|
||||
select(MemoryAssociative).where(
|
||||
MemoryAssociative.id == rid,
|
||||
MemoryAssociative.user_id == user_id,
|
||||
)
|
||||
)
|
||||
target = result2.scalar_one_or_none()
|
||||
if target:
|
||||
await db.delete(target)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||
rid, user_id, item.get("reason", ""),
|
||||
)
|
||||
|
||||
if deleted:
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info(
|
||||
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||
)
|
||||
|
||||
|
||||
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
all_labels: set[str] = set()
|
||||
for row in rows:
|
||||
all_labels.add(row.subject_label)
|
||||
all_labels.add(row.object_label)
|
||||
|
||||
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||
if len(labels_list) < 2:
|
||||
return
|
||||
|
||||
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Canonicalize entity labels."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-canonicalize",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
groups = json.loads(text.strip())
|
||||
if not isinstance(groups, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
# Build variant → canonical map
|
||||
remap: dict[str, str] = {}
|
||||
for group in groups:
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
canonical = group.get("canonical", "")
|
||||
variants = group.get("variants") or []
|
||||
if not canonical:
|
||||
continue
|
||||
for v in variants:
|
||||
if isinstance(v, str) and v != canonical:
|
||||
remap[v] = canonical
|
||||
|
||||
if not remap:
|
||||
return
|
||||
|
||||
updated = 0
|
||||
for row in rows:
|
||||
changed = False
|
||||
if row.subject_label in remap:
|
||||
row.subject_label = remap[row.subject_label]
|
||||
changed = True
|
||||
if row.object_label in remap:
|
||||
row.object_label = remap[row.object_label]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
|
||||
if updated:
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||
user_id, updated,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
@@ -18,8 +18,10 @@ Usage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
@@ -27,15 +29,22 @@ from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# Tuning constants
|
||||
_ASSOCIATIVE_TOP_K = 5
|
||||
_EPISODIC_RECENT_N = 10
|
||||
@@ -64,26 +73,31 @@ class MemoryMiddleware:
|
||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return {}
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier: str = user_dbg.get("tier") or "free"
|
||||
|
||||
core = await self._load_core(user_id, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||
proactive = await self._load_proactive(user_id, fernet)
|
||||
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
user_tier,
|
||||
len(core),
|
||||
len(associative),
|
||||
len(episodic),
|
||||
len(proactive),
|
||||
len(relational),
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -91,6 +105,7 @@ class MemoryMiddleware:
|
||||
"associative_memory": associative,
|
||||
"episodic_memory": episodic,
|
||||
"proactive_hints": proactive,
|
||||
"relational_memory": relational,
|
||||
}
|
||||
|
||||
async def store_episode(
|
||||
@@ -104,7 +119,10 @@ class MemoryMiddleware:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. Full LLM summarisation can be added in a later step.
|
||||
latency low. After committing the episode row, dispatches the Mem0-style
|
||||
extraction pipeline:
|
||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
@@ -113,26 +131,95 @@ class MemoryMiddleware:
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
row = MemoryEpisodic(
|
||||
episode = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
self._db.add(episode)
|
||||
episode_id: str = episode.id
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
tier = user_dbg.get("tier") or "free"
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
tier,
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
return
|
||||
|
||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||
await self._dispatch_extraction(
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
last_user_msg=message,
|
||||
last_assistant_msg=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _dispatch_extraction(
|
||||
self,
|
||||
user_id: str,
|
||||
episode_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, self._db)
|
||||
|
||||
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||
# Pro/Power/Team: fire-and-forget in the background.
|
||||
# Must open a fresh session — request session closes after handler returns.
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
|
||||
async def _task() -> None:
|
||||
try:
|
||||
async with async_session() as fresh_db:
|
||||
await run_extraction(
|
||||
db=fresh_db,
|
||||
user_id=user_id,
|
||||
last_user_msg=last_user_msg,
|
||||
last_assistant_msg=last_assistant_msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||
)
|
||||
|
||||
asyncio.create_task(_task())
|
||||
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||
else:
|
||||
# Free tier: enqueue for daily batch cron.
|
||||
queue_row = ExtractionQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
)
|
||||
self._db.add(queue_row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||
user_id,
|
||||
episode_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
@@ -255,6 +342,143 @@ class MemoryMiddleware:
|
||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||
return True
|
||||
|
||||
async def store_associative(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
) -> None:
|
||||
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
|
||||
embedding: list[float] | None = None
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
embedding = await embed_text(content)
|
||||
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=embedding,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: store_associative user=%s embedded=%s",
|
||||
user_id,
|
||||
embedding is not None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def upsert_relation(
|
||||
self,
|
||||
user_id: str,
|
||||
subject: str,
|
||||
subject_type: str,
|
||||
predicate: str,
|
||||
object_: str,
|
||||
object_type: str,
|
||||
*,
|
||||
confidence: float = 0.7,
|
||||
source_episode_id: str | None = None,
|
||||
notes: str | None = None,
|
||||
) -> None:
|
||||
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
|
||||
|
||||
subject_label / object_label are plaintext entity identifiers — not encrypted.
|
||||
notes is optional; encrypted with user Fernet if provided.
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
|
||||
return
|
||||
|
||||
notes_encrypted: bytes | None = None
|
||||
if notes:
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet:
|
||||
notes_encrypted = fernet.encrypt(notes.encode())
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.user_id == user_id,
|
||||
MemoryRelation.subject_label == subject,
|
||||
MemoryRelation.predicate == predicate,
|
||||
MemoryRelation.object_label == object_,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
existing.subject_type = subject_type
|
||||
existing.object_type = object_type
|
||||
existing.confidence = confidence
|
||||
existing.last_confirmed_at = _now()
|
||||
if notes_encrypted is not None:
|
||||
existing.notes_encrypted = notes_encrypted
|
||||
else:
|
||||
self._db.add(MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
subject_label=subject,
|
||||
subject_type=subject_type,
|
||||
predicate=predicate,
|
||||
object_label=object_,
|
||||
object_type=object_type,
|
||||
confidence=confidence,
|
||||
source_episode_id=source_episode_id,
|
||||
notes_encrypted=notes_encrypted,
|
||||
))
|
||||
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
|
||||
user_id, subject, predicate, object_,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def query_relations(
|
||||
self,
|
||||
user_id: str,
|
||||
subject: str | None = None,
|
||||
predicate: str | None = None,
|
||||
object_: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[MemoryRelation]:
|
||||
"""Query relation rows for a user with optional filters."""
|
||||
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
if subject is not None:
|
||||
q = q.where(MemoryRelation.subject_label == subject)
|
||||
if predicate is not None:
|
||||
q = q.where(MemoryRelation.predicate == predicate)
|
||||
if object_ is not None:
|
||||
q = q.where(MemoryRelation.object_label == object_)
|
||||
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
|
||||
result = await self._db.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
@@ -343,13 +567,26 @@ class MemoryMiddleware:
|
||||
|
||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||
"""Load lightweight user debug fields for trace logs."""
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
return {"tier": None}
|
||||
return {
|
||||
"tier": user.tier,
|
||||
}
|
||||
|
||||
sub_result = await self._db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub_tier: str | None = sub_result.scalar_one_or_none()
|
||||
if sub_tier:
|
||||
tier = sub_tier
|
||||
elif settings.ENV == "dev":
|
||||
tier = "power"
|
||||
else:
|
||||
tier = user.tier or "free"
|
||||
|
||||
return {"tier": tier}
|
||||
|
||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||
result = await self._db.execute(
|
||||
@@ -364,14 +601,49 @@ class MemoryMiddleware:
|
||||
return out
|
||||
|
||||
async def _load_associative(
|
||||
self, user_id: str, message: str, fernet: Fernet
|
||||
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||
) -> list[str]:
|
||||
"""Load top-k associative memories.
|
||||
|
||||
Production: uses pgvector cosine similarity on the message embedding.
|
||||
Current implementation: keyword-based fallback (no external embedding call)
|
||||
so tests pass without a live OpenAI key.
|
||||
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
vec = await embed_text(message)
|
||||
if vec is not None:
|
||||
try:
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(
|
||||
MemoryAssociative.user_id == user_id,
|
||||
MemoryAssociative.embedding.isnot(None),
|
||||
)
|
||||
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
logger.info(
|
||||
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||
user_id,
|
||||
len(out),
|
||||
)
|
||||
return out
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Keyword fallback: most recent rows
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
@@ -379,7 +651,7 @@ class MemoryMiddleware:
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
out = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
@@ -408,6 +680,26 @@ class MemoryMiddleware:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
|
||||
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryRelation)
|
||||
.where(MemoryRelation.user_id == user_id)
|
||||
.order_by(MemoryRelation.confidence.desc())
|
||||
.limit(10)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out = [
|
||||
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
|
||||
for r in rows
|
||||
]
|
||||
return out
|
||||
|
||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryProactive)
|
||||
|
||||
51
app/core/note_summarizer.py
Normal file
51
app/core/note_summarizer.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Note summarizer — generates a compact AI summary for a note.
|
||||
|
||||
Called fire-and-forget from create_note / update_note tools so the
|
||||
``notes.ai_summary`` column stays current without blocking the agent loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.core.langfuse_client import get_prompt_or_fallback
|
||||
from app.core.llm import get_agent_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FALLBACK_PROMPT = """\
|
||||
Summarize this note in <=250 characters. Be terse and dense.
|
||||
Keep proper nouns, dates, decisions, and action items.
|
||||
Do not start with "This note".
|
||||
Respond with the summary text only — no intro, no labels.
|
||||
|
||||
Title: {title}
|
||||
Content: {content}"""
|
||||
|
||||
_MAX_CONTENT_CHARS = 4000
|
||||
|
||||
|
||||
async def generate_note_summary(title: str, content: str) -> str:
|
||||
"""Return a <=250-char summary of *title* + *content*.
|
||||
|
||||
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
|
||||
fallback. Truncates *content* to 4000 chars before sending to avoid
|
||||
token waste on large notes.
|
||||
"""
|
||||
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
|
||||
trimmed = content[:_MAX_CONTENT_CHARS]
|
||||
system_prompt = template.format(title=title, content=trimmed)
|
||||
|
||||
try:
|
||||
llm = get_agent_llm("note-summarizer")
|
||||
response = await llm.ainvoke([
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content="Generate the summary."),
|
||||
])
|
||||
text = response.content if isinstance(response.content, str) else ""
|
||||
return text.strip()[:250]
|
||||
except Exception as exc:
|
||||
logger.warning("note_summarizer: failed to generate summary: %s", exc)
|
||||
return ""
|
||||
@@ -2,12 +2,36 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
||||
_CANVAS_BLOCK_RE = re.compile(
|
||||
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
||||
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
||||
|
||||
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
||||
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
||||
"""
|
||||
match = _CANVAS_BLOCK_RE.search(text)
|
||||
if not match:
|
||||
return text, None, None
|
||||
|
||||
canvas_kind = match.group(1).strip()
|
||||
canvas_content = match.group(2).strip()
|
||||
visible = text[: match.start()] + text[match.end() :]
|
||||
visible = visible.strip()
|
||||
return visible, canvas_content, canvas_kind
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
|
||||
|
||||
|
||||
class StreamFormatter:
|
||||
@@ -23,14 +47,6 @@ class StreamFormatter:
|
||||
started = False
|
||||
|
||||
async for event_type, data in event_stream:
|
||||
if event_type == "floating_domain":
|
||||
if isinstance(data, dict):
|
||||
yield WsFloatingDomain(
|
||||
request_id=self.request_id,
|
||||
domain=data,
|
||||
)
|
||||
continue
|
||||
|
||||
if event_type != "token":
|
||||
continue
|
||||
|
||||
|
||||
@@ -7,10 +7,32 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
|
||||
|
||||
|
||||
def _key_to_camel(key: str) -> str:
|
||||
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
|
||||
|
||||
|
||||
def _keys_to_camel(obj: Any) -> Any:
|
||||
"""Recursively convert dict keys from snake_case to camelCase.
|
||||
|
||||
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
|
||||
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
|
||||
tool_result payloads in ``toSnakeCase`` before sending; this restores the
|
||||
camelCase schema property names that the tool code expects to read.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_keys_to_camel(v) for v in obj]
|
||||
return obj
|
||||
|
||||
# Holds the execute callback for the current WS session.
|
||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||
@@ -82,6 +104,7 @@ async def execute_on_client(
|
||||
payload["limit"] = limit
|
||||
|
||||
result = await callback(payload)
|
||||
result = _keys_to_camel(result)
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append({
|
||||
|
||||
83
app/main.py
83
app/main.py
@@ -4,6 +4,10 @@ import logging
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
@@ -11,9 +15,66 @@ logging.basicConfig(
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
async def _memory_audit_cron_tick() -> None:
|
||||
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory audit cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await audit_memory(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _memory_cron_tick() -> None:
|
||||
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
await drain_extraction_queue(db)
|
||||
|
||||
# mine proactive patterns for every Power+ user
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(uid, db)
|
||||
if tier_manager.check_feature(tier, "proactive_mining"):
|
||||
await mine_proactive_patterns(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -21,8 +82,21 @@ async def lifespan(app: FastAPI):
|
||||
# Startup: ensure agent tool modules are loaded.
|
||||
import app.agents # noqa: F401
|
||||
|
||||
scheduler = None
|
||||
if settings.SCHEDULER_ENABLED:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||
scheduler.start()
|
||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||
|
||||
yield
|
||||
|
||||
if scheduler is not None:
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
from app.db import engine
|
||||
await engine.dispose()
|
||||
@@ -50,13 +124,14 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import agents, auth, billing, chat, device_ws
|
||||
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
app.include_router(memory.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
|
||||
101
app/models.py
101
app/models.py
@@ -14,6 +14,7 @@ Table inventory:
|
||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||
memory_episodic — per-user session summaries (encrypted)
|
||||
memory_proactive — per-user behavioral patterns (encrypted)
|
||||
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -21,6 +22,7 @@ from __future__ import annotations
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
@@ -29,6 +31,7 @@ from sqlalchemy import (
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
LargeBinary,
|
||||
String,
|
||||
Text,
|
||||
Uuid,
|
||||
@@ -240,6 +243,7 @@ class AgentRunLog(Base):
|
||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
@@ -260,6 +264,17 @@ class AgentRunLog(Base):
|
||||
)
|
||||
|
||||
|
||||
class MonthlyTokenUsage(Base):
|
||||
__tablename__ = "monthly_token_usage"
|
||||
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
|
||||
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
|
||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -299,8 +314,8 @@ class MemoryAssociative(Base):
|
||||
nullable=False, index=True,
|
||||
)
|
||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
@@ -348,3 +363,85 @@ class MemoryProactive(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class ExtractionQueue(Base):
|
||||
"""Batch extraction queue for Free-tier users (Phase 2).
|
||||
|
||||
Pro/Power/Team users get realtime asyncio.create_task() extraction.
|
||||
Free users get a queue row here; a daily cron (Phase 5) drains it.
|
||||
"""
|
||||
|
||||
__tablename__ = "extraction_queue"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
episode_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False), nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryRelation(Base):
|
||||
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
|
||||
|
||||
subject_label/object_label are plaintext entity identifiers (not user content).
|
||||
notes_encrypted is optional Fernet-encrypted per-user commentary.
|
||||
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_relations"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
|
||||
source_episode_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False),
|
||||
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
last_confirmed_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
"""Plugin marketplace catalog entry."""
|
||||
|
||||
__tablename__ = "plugins"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -73,11 +73,9 @@ class WsFrameType(str, Enum):
|
||||
device_hello = "device_hello"
|
||||
# ── v3 frame types ─────────────────────────────────────────────────
|
||||
home_request = "home_request"
|
||||
floating_request = "floating_request"
|
||||
stream_start = "stream_start"
|
||||
stream_text = "stream_text"
|
||||
stream_end = "stream_end"
|
||||
floating_domain = "floating_domain"
|
||||
data_request = "data_request"
|
||||
data_response = "data_response"
|
||||
mutation = "mutation"
|
||||
@@ -85,6 +83,21 @@ class WsFrameType(str, Enum):
|
||||
journey_start = "journey_start"
|
||||
journey_message = "journey_message"
|
||||
journey_reply = "journey_reply"
|
||||
# ── v5 brief frame types ──────────────────────────────────────────
|
||||
brief_request = "brief_request"
|
||||
# ── v6 task brief frame types ─────────────────────────────────────
|
||||
task_brief_request = "task_brief_request"
|
||||
# ── v7 folder index frame types ───────────────────────────────────
|
||||
index_session_start = "index_session_start"
|
||||
index_file_batch = "index_file_batch"
|
||||
index_session_cancel = "index_session_cancel"
|
||||
index_file_result = "index_file_result"
|
||||
index_session_progress = "index_session_progress"
|
||||
index_session_done = "index_session_done"
|
||||
# ── v8 contextual sidebar frame types ────────────────────────────
|
||||
contextual_request = "contextual_request"
|
||||
contextual_scope_update = "contextual_scope_update"
|
||||
contextual_scope_ack = "contextual_scope_ack"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
@@ -140,11 +153,14 @@ class WsDeviceHello(BaseModel):
|
||||
|
||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||
|
||||
class WsFloatingScope(BaseModel):
|
||||
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||
class FormatPrefsModel(BaseModel):
|
||||
"""User display preferences sent by Electron on each request."""
|
||||
|
||||
type: Literal["task", "project", "note", "timeline"]
|
||||
id: str | None = None
|
||||
timezone: str = "UTC"
|
||||
date_format: str = "dd/MM/yyyy"
|
||||
time_format: str = "24h"
|
||||
locale: str = "en-US"
|
||||
now_iso: str = ""
|
||||
|
||||
|
||||
class WsHomeRequest(BaseModel):
|
||||
@@ -153,14 +169,18 @@ class WsHomeRequest(BaseModel):
|
||||
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||
message: str
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
format_prefs: FormatPrefsModel | None = None
|
||||
|
||||
|
||||
class WsFloatingRequest(BaseModel):
|
||||
"""Client → Server: Floating chat message scoped to an entity."""
|
||||
class WsBriefRequest(BaseModel):
|
||||
"""Client → Server: Request a plain-text brief (home or project)."""
|
||||
|
||||
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||
message: str
|
||||
scope: WsFloatingScope
|
||||
type: Literal[WsFrameType.brief_request] = WsFrameType.brief_request
|
||||
request_id: str | None = None
|
||||
session_id: str | None = None
|
||||
mode: Literal["home", "project"]
|
||||
project_id: str | None = None
|
||||
format_prefs: FormatPrefsModel | None = None
|
||||
|
||||
|
||||
class WsStreamStart(BaseModel):
|
||||
@@ -183,22 +203,8 @@ class WsStreamEnd(BaseModel):
|
||||
|
||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||
request_id: str
|
||||
|
||||
|
||||
class WsDomain(BaseModel):
|
||||
"""Structured floating domain payload for UI routing decisions."""
|
||||
|
||||
type: Literal["task", "timeline", "project", "node"]
|
||||
id: str | None = None
|
||||
section: Literal["task", "timeline", "note"] | None = None
|
||||
|
||||
|
||||
class WsFloatingDomain(BaseModel):
|
||||
"""Server → Client: domain determined for a floating request."""
|
||||
|
||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||
request_id: str
|
||||
domain: WsDomain
|
||||
error: str | None = None
|
||||
mutations: list[dict[str, Any]] | None = None
|
||||
|
||||
|
||||
# ── Agent Config V2 ───────────────────────────────────────────────────
|
||||
73
app/schemas/contextual.py
Normal file
73
app/schemas/contextual.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Contextual sidebar scope schema and prompt block renderer.
|
||||
|
||||
ContextualScope mirrors the TypeScript ContextualScope type sent by the
|
||||
Electron renderer when the user opens the side chat anchored to a specific
|
||||
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
|
||||
them to snake_case Python attributes automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
|
||||
PageType = Literal[
|
||||
"timeline",
|
||||
"tasks",
|
||||
"projects-list",
|
||||
"project",
|
||||
"note",
|
||||
]
|
||||
|
||||
EntityType = Literal["project", "note", "task", "timeline_event"]
|
||||
|
||||
|
||||
class ContextualScope(BaseModel):
|
||||
"""Scope payload sent by the Electron renderer for contextual chat.
|
||||
|
||||
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
|
||||
alias generator maps them to snake_case Python attrs.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
|
||||
|
||||
page: PageType
|
||||
entity_type: Optional[EntityType] = None
|
||||
entity_id: Optional[str] = None
|
||||
entity_name: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
char_count: Optional[int] = None
|
||||
counts: Optional[dict[str, int]] = None
|
||||
filters: Optional[dict] = None
|
||||
|
||||
|
||||
def render_scope_block(scope: ContextualScope) -> str:
|
||||
"""Produce a single-paragraph human-readable summary of the current view
|
||||
for injection into the contextual agent system prompt.
|
||||
|
||||
Never emits internal ids — only names. The LLM is told to use names in
|
||||
prose; ids travel through tool calls.
|
||||
"""
|
||||
if scope.entity_type == "project":
|
||||
c = scope.counts or {}
|
||||
return (
|
||||
f"User is viewing the project {scope.entity_name!r}. "
|
||||
f"{c.get('tasks', 0)} tasks, "
|
||||
f"{c.get('notes', 0)} notes, "
|
||||
f"{c.get('milestones', 0)} milestones."
|
||||
)
|
||||
if scope.entity_type == "note":
|
||||
return (
|
||||
f"User is viewing the note {scope.entity_name!r} "
|
||||
f"({scope.char_count or 0} characters)."
|
||||
)
|
||||
if scope.page == "tasks":
|
||||
return "User is viewing the global Tasks list (all projects)."
|
||||
if scope.page == "timeline":
|
||||
return "User is viewing the global Timeline view."
|
||||
if scope.page == "projects-list":
|
||||
return "User is viewing the Projects list."
|
||||
return f"User is on page {scope.page}."
|
||||
@@ -32,8 +32,12 @@ google-auth-oauthlib>=1.2.0
|
||||
google-auth-httplib2>=0.2.0
|
||||
msal>=1.28.0
|
||||
cryptography>=42.0.0
|
||||
langfuse>=2.0.0
|
||||
pgvector>=0.2.5
|
||||
langfuse>=3.3.1
|
||||
beautifulsoup4>=4.12.0
|
||||
lxml>=5.0.0
|
||||
PyYAML>=6.0.0
|
||||
apscheduler>=3.10.0
|
||||
ruff>=0.8.0
|
||||
pypdf>=4.0
|
||||
python-docx>=1.1
|
||||
|
||||
1
results.xml
Normal file
1
results.xml
Normal file
File diff suppressed because one or more lines are too long
@@ -17,6 +17,8 @@ from jose import jwt
|
||||
from sqlalchemy import StaticPool, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import Base, get_session
|
||||
from app.main import app
|
||||
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||
|
||||
|
||||
# ── Convenience aliases and per-tier user fixtures ────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db(db_session: AsyncSession) -> AsyncSession:
|
||||
"""Alias for db_session — used by folder quota tests."""
|
||||
return db_session
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_free(db_session: AsyncSession):
|
||||
"""Return the seeded free-tier User row."""
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.id == TEST_USER_IDS["free"])
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_power(db_session: AsyncSession):
|
||||
"""Return the seeded power-tier User row."""
|
||||
result = await db_session.execute(
|
||||
select(User).where(User.id == TEST_USER_IDS["power"])
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_headers_free() -> dict[str, str]:
|
||||
"""Authorization header for the seeded free-tier user."""
|
||||
return auth_header("free")
|
||||
|
||||
|
||||
# ── CLI options ───────────────────────────────────────────────────────
|
||||
|
||||
def pytest_addoption(parser):
|
||||
|
||||
@@ -1,808 +0,0 @@
|
||||
"""Tests for Step 3.4: agent_runner module.
|
||||
|
||||
Coverage:
|
||||
Unit:
|
||||
- _is_overdue — cron schedule overdue detection
|
||||
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
||||
- _send_insert_to_client — tool_call frame construction + timeout
|
||||
- run_local_agent — end-to-end local agent happy path
|
||||
- run_local_agent — device offline path
|
||||
- run_local_agent — file-read timeout path
|
||||
- run_local_agent — LLM extraction error path
|
||||
- run_cloud_agent — stub returns error immediately
|
||||
- trigger_pending_runs — skipped when config is client-owned
|
||||
- trigger_pending_runs — non-overdue skipped
|
||||
- trigger_pending_runs — device_id filter for local agents
|
||||
|
||||
Integration:
|
||||
- POST /agents/can-create — billing eligibility check
|
||||
- POST /agents/trigger — creates run log + dispatches background task
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_runner import (
|
||||
_extract_items_from_content,
|
||||
_is_overdue,
|
||||
_send_insert_to_client,
|
||||
run_cloud_agent,
|
||||
run_local_agent,
|
||||
trigger_pending_runs,
|
||||
)
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FREE_UID = TEST_USER_IDS["free"]
|
||||
_PRO_UID = TEST_USER_IDS["pro"]
|
||||
|
||||
|
||||
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
||||
return LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
name="Test Local Agent",
|
||||
directory_paths=["/home/user/emails"],
|
||||
data_types=["tasks", "notes"],
|
||||
prompt_template="Extract tasks and notes from this document.",
|
||||
file_extensions=[".txt", ".eml"],
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
last_run_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
||||
return CloudAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
provider="gmail",
|
||||
name="Test Gmail Agent",
|
||||
data_types=["tasks"],
|
||||
prompt_template="Extract tasks from email.",
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
last_run_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
||||
return AgentRunLog(
|
||||
id=str(uuid.uuid4()),
|
||||
agent_id=agent_id,
|
||||
agent_type=agent_type,
|
||||
user_id=user_id,
|
||||
status="running",
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
||||
mgr = DeviceConnectionManager()
|
||||
ws = MagicMock()
|
||||
ws.send_text = AsyncMock()
|
||||
mgr.register(user_id, device_id, ws)
|
||||
return mgr
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_overdue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_is_overdue_never_run():
|
||||
"""An agent that has never run is always overdue."""
|
||||
assert _is_overdue("0 */6 * * *", None) is True
|
||||
|
||||
|
||||
def test_is_overdue_very_recently_run():
|
||||
"""An agent that just ran is not overdue."""
|
||||
last = datetime.now(timezone.utc)
|
||||
assert _is_overdue("0 */6 * * *", last) is False
|
||||
|
||||
|
||||
def test_is_overdue_long_ago():
|
||||
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
||||
from datetime import timedelta
|
||||
last = datetime.now(timezone.utc) - timedelta(days=2)
|
||||
assert _is_overdue("0 */6 * * *", last) is True
|
||||
|
||||
|
||||
def test_is_overdue_invalid_cron_returns_false():
|
||||
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
||||
assert _is_overdue("not a cron", None) is False
|
||||
|
||||
|
||||
def test_is_overdue_naive_datetime():
|
||||
"""Naive datetime objects are handled without raising."""
|
||||
from datetime import timedelta
|
||||
last = datetime.utcnow() - timedelta(days=1) # naive
|
||||
# Should not raise.
|
||||
result = _is_overdue("0 */6 * * *", last)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_items_from_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_happy_path():
|
||||
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
||||
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content(
|
||||
"Extract tasks and notes.",
|
||||
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
||||
["tasks", "notes"],
|
||||
)
|
||||
|
||||
assert len(items) == 2
|
||||
assert items[0]["table"] == "tasks"
|
||||
assert items[0]["data"]["title"] == "Buy milk"
|
||||
assert items[1]["table"] == "notes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_strips_forbidden_fields():
|
||||
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{
|
||||
"table": "tasks",
|
||||
"data": {
|
||||
"title": "Review PR",
|
||||
"id": "should-be-removed",
|
||||
"createdAt": 99999,
|
||||
"isAiSuggested": 0,
|
||||
"isApproved": 1,
|
||||
},
|
||||
}
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
||||
|
||||
assert len(items) == 1
|
||||
data = items[0]["data"]
|
||||
assert "id" not in data
|
||||
assert "createdAt" not in data
|
||||
assert "isAiSuggested" not in data
|
||||
assert "isApproved" not in data
|
||||
assert data["title"] == "Review PR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_invalid_json_returns_empty():
|
||||
"""LLM returning invalid JSON must return empty list without raising."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "Sorry, I cannot extract anything."
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||
|
||||
assert items == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_disallowed_table_filtered():
|
||||
"""Items whose table is not in data_types are discarded."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{"table": "tasks", "data": {"title": "Valid task"}},
|
||||
{"table": "projects", "data": {"name": "Should be filtered"}},
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
# Only "tasks" is in data_types — "projects" should be filtered.
|
||||
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["table"] == "tasks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_empty_data_types_returns_empty():
|
||||
"""If no allowed data_types match, skip LLM call and return immediately."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock()
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content("Extract.", "content", [])
|
||||
|
||||
mock_llm.ainvoke.assert_not_called()
|
||||
assert items == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_llm_error_propagates():
|
||||
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
with pytest.raises(RuntimeError, match="API unavailable"):
|
||||
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_insert_to_client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_insert_to_client_happy_path():
|
||||
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
||||
mgr = _make_manager()
|
||||
|
||||
sent_payloads: list[dict] = []
|
||||
original_send = mgr.send_frame
|
||||
|
||||
async def _capture_send(uid: str, frame: dict) -> None:
|
||||
sent_payloads.append(frame)
|
||||
# Immediately resolve the pending call with a success result.
|
||||
call_id = frame["id"]
|
||||
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
||||
|
||||
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
||||
|
||||
result = await _send_insert_to_client(
|
||||
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
||||
)
|
||||
|
||||
assert len(sent_payloads) == 1
|
||||
payload = sent_payloads[0]
|
||||
assert payload["action"] == "insert"
|
||||
assert payload["table"] == "tasks"
|
||||
assert payload["data"]["title"] == "Buy milk"
|
||||
assert payload["data"]["isAiSuggested"] == 1
|
||||
assert payload["data"]["isApproved"] == 0
|
||||
assert result["row"]["title"] == "Buy milk"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_insert_to_client_timeout():
|
||||
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
||||
mgr = _make_manager()
|
||||
|
||||
async def _slow_send(uid: str, frame: dict) -> None:
|
||||
# Never resolve the pending call.
|
||||
pass
|
||||
|
||||
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
||||
|
||||
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_local_agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_device_offline():
|
||||
"""run_local_agent marks run as error when device is offline."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = DeviceConnectionManager() # Empty — no device registered.
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("not connected" in e for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_happy_path():
|
||||
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = _make_manager()
|
||||
|
||||
# Build a fake agent_data frame (will be queued after send).
|
||||
file_frame = {
|
||||
"type": "agent_data",
|
||||
"run_id": run_log.id,
|
||||
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
||||
}
|
||||
agent_complete_frame = None # sentinel
|
||||
|
||||
sent_frames: list[dict] = []
|
||||
|
||||
async def _mock_send(uid: str, frame: dict) -> None:
|
||||
sent_frames.append(frame)
|
||||
if frame.get("type") == "agent_run":
|
||||
# Simulate Electron responding with file data then agent_complete.
|
||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||
await q.put(file_frame)
|
||||
await q.put(agent_complete_frame)
|
||||
elif frame.get("type") == "tool_call":
|
||||
# Resolve the pending insert immediately.
|
||||
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
||||
|
||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "success"
|
||||
assert kwargs["items_processed"] == 1
|
||||
assert kwargs["items_created"] == 1
|
||||
assert kwargs["errors"] == []
|
||||
assert kwargs["update_config_last_run"] is False
|
||||
|
||||
# Verify agent_run frame was sent.
|
||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||
assert len(agent_run_frames) == 1
|
||||
assert agent_run_frames[0]["agent_id"] == config.id
|
||||
assert "paths" in agent_run_frames[0]["config"]
|
||||
|
||||
# Verify insert frame was sent with AI flags.
|
||||
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
||||
assert len(insert_frames) == 1
|
||||
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
||||
assert insert_frames[0]["data"]["isApproved"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_file_read_timeout():
|
||||
"""run_local_agent marks run as partial/error when device stops sending files."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = _make_manager()
|
||||
|
||||
async def _mock_send(uid: str, frame: dict) -> None:
|
||||
# Don't put anything in the queue — simulate stalled device.
|
||||
pass
|
||||
|
||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||
|
||||
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
||||
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_llm_extraction_error():
|
||||
"""LLM errors per-file are recorded; run continues for remaining files."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = _make_manager()
|
||||
|
||||
file_frame = {
|
||||
"type": "agent_data",
|
||||
"run_id": run_log.id,
|
||||
"files": [
|
||||
{"path": "/file1.eml", "content": "Email one."},
|
||||
{"path": "/file2.eml", "content": "Email two."},
|
||||
],
|
||||
}
|
||||
|
||||
async def _mock_send(uid: str, frame: dict) -> None:
|
||||
if frame.get("type") == "agent_run":
|
||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||
await q.put(file_frame)
|
||||
await q.put(None) # agent_complete sentinel
|
||||
|
||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert kwargs["items_processed"] == 2 # Both files attempted.
|
||||
assert kwargs["items_created"] == 0
|
||||
assert len(kwargs["errors"]) == 2 # One error per file.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_cloud_agent (stub)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_device_offline():
|
||||
"""Cloud agent aborts immediately when no device is connected."""
|
||||
config = _make_cloud_config()
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = DeviceConnectionManager() # empty — no devices registered
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_no_oauth_token():
|
||||
"""Cloud agent errors when no OAuth token is stored."""
|
||||
config = _make_cloud_config()
|
||||
config.oauth_token_encrypted = None
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_token_decrypt_failure():
|
||||
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
||||
config = _make_cloud_config()
|
||||
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
valid_key = _Fernet.generate_key().decode()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||
patch("app.integrations.settings") as mock_settings:
|
||||
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_happy_path_gmail():
|
||||
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
||||
from app.integrations import EmailMessage, encrypt_token
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
|
||||
fernet_key = _Fernet.generate_key().decode()
|
||||
credentials = {
|
||||
"token": "access_abc",
|
||||
"refresh_token": "refresh_xyz",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "cid",
|
||||
"client_secret": "csec",
|
||||
}
|
||||
|
||||
config = _make_cloud_config()
|
||||
config.provider = "gmail"
|
||||
config.prompt_template = "Extract tasks from this email."
|
||||
config.data_types = ["tasks"]
|
||||
|
||||
with patch("app.integrations.settings") as ms:
|
||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
sample_email = EmailMessage(
|
||||
id="msg001",
|
||||
subject="Action required",
|
||||
sender="boss@company.com",
|
||||
body_text="Please fix the bug by Friday.",
|
||||
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
||||
|
||||
with patch("app.integrations.settings") as mock_int_settings, \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
||||
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
||||
patch("app.core.agent_runner.async_session"):
|
||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
|
||||
mock_gmail = AsyncMock()
|
||||
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
||||
mock_gmail.refreshed_credentials = None
|
||||
|
||||
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||
patch("app.integrations.get_provider", return_value=mock_gmail):
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_extract.assert_called_once()
|
||||
mock_insert.assert_called_once()
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "success"
|
||||
assert kwargs["items_processed"] == 1
|
||||
assert kwargs["items_created"] == 1
|
||||
assert kwargs["config_type"] == "cloud"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_provider_fetch_error():
|
||||
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
||||
credentials = {"token": "abc"}
|
||||
config = _make_cloud_config()
|
||||
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
||||
config.prompt_template = "Extract tasks."
|
||||
config.data_types = ["tasks"]
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
||||
mock_provider.refreshed_credentials = None
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||
patch("app.core.agent_runner.async_session"):
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||
from app.integrations import encrypt_token
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
|
||||
fernet_key = _Fernet.generate_key().decode()
|
||||
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
||||
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
||||
|
||||
config = _make_cloud_config()
|
||||
config.prompt_template = "Extract tasks."
|
||||
config.data_types = ["tasks"]
|
||||
|
||||
with patch("app.integrations.settings") as ms:
|
||||
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
mock_provider = AsyncMock()
|
||||
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
||||
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
||||
|
||||
# Track DB writes via mock async_session.
|
||||
mock_cfg_row = MagicMock()
|
||||
mock_cfg_row.oauth_token_encrypted = None
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
||||
cfg_result = MagicMock()
|
||||
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
||||
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
||||
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
||||
patch("app.integrations.settings") as mock_int_settings:
|
||||
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
# The new encrypted token should have been written to the config row.
|
||||
mock_encrypt.assert_called_once_with(fresh_credentials)
|
||||
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finalize_run_updates_cloud_config_last_run_at():
|
||||
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
||||
from app.core.agent_runner import _finalize_run
|
||||
|
||||
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
||||
run_log.id = str(uuid.uuid4())
|
||||
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.last_run_at = None
|
||||
|
||||
cfg_result = MagicMock()
|
||||
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
||||
|
||||
mock_db = AsyncMock()
|
||||
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_db.merge = AsyncMock(return_value=run_log)
|
||||
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
config_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="success",
|
||||
update_config_last_run=True,
|
||||
config_id=config_id,
|
||||
config_type="cloud",
|
||||
)
|
||||
|
||||
# CloudAgentConfig.last_run_at should have been set.
|
||||
assert mock_cfg.last_run_at is not None
|
||||
mock_db.commit.assert_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# trigger_pending_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_no_overdue():
|
||||
"""Pending-run scan is skipped because agent config is client-owned."""
|
||||
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_device_id_filter():
|
||||
"""Device filtering is no longer backend-managed in pending runs."""
|
||||
|
||||
mgr = _make_manager(device_id="dev-001")
|
||||
|
||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_dispatches_overdue():
|
||||
"""No pending runs are dispatched by backend after config deprecation."""
|
||||
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: POST /agents/can-create and /agents/trigger
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
"""Route all get_session calls to the test SQLite session."""
|
||||
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_create_agent_allows_when_under_limit(client):
|
||||
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
||||
resp = client.post(
|
||||
"/api/v1/agents/can-create",
|
||||
json={"active_agents": 0},
|
||||
headers=auth_header("free"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["allowed"] is True
|
||||
assert body["tier"] == "free"
|
||||
assert body["active_agents"] == 0
|
||||
assert body["limit"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_create_agent_denies_when_at_limit(client):
|
||||
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
||||
resp = client.post(
|
||||
"/api/v1/agents/can-create",
|
||||
json={"active_agents": 2},
|
||||
headers=auth_header("free"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["allowed"] is False
|
||||
assert body["limit"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
||||
dispatched: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||
dispatched.append((user_id, cfg.id))
|
||||
|
||||
def _fake_create_task(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||
patch("asyncio.create_task") as mock_create_task:
|
||||
mock_create_task.side_effect = _fake_create_task
|
||||
resp = client.post(
|
||||
"/api/v1/agents/trigger",
|
||||
json={
|
||||
"directory": "/home/user/docs",
|
||||
"what_to_extract": ["task", "note"],
|
||||
"batch_interval": "0 */6 * * *",
|
||||
"custom_agent_prompt": "Extract tasks and notes.",
|
||||
"active_agents": 0,
|
||||
},
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert isinstance(data["agent_id"], str)
|
||||
assert data["agent_id"]
|
||||
assert data["status"] == "running"
|
||||
assert data["agent_type"] == "local"
|
||||
|
||||
# Verify create_task was called (dispatching background run).
|
||||
mock_create_task.assert_called_once()
|
||||
@@ -382,7 +382,6 @@ async def test_eval_runner(runner_case, pytestconfig):
|
||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||
|
||||
_, kwargs = mock_fin.call_args
|
||||
inserts = [c for c in calls if c["action"] == "insert"]
|
||||
score, comment = _evaluate_case(case, calls, kwargs)
|
||||
|
||||
if obs is not None:
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
"""Tests for the Chatbot Journey endpoints.
|
||||
|
||||
Covers:
|
||||
1. Start journey for local agent → session_id + first question, done=False
|
||||
2. Start journey for cloud agent → contextual email-focused question
|
||||
3. Start journey with existing agent_id → session seeded, first question returned
|
||||
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
|
||||
5. Message: continue conversation → done=False, follow-up question returned
|
||||
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
|
||||
7. Message with max-turns nudge → no crash, returns response
|
||||
8. Invalid session_id → 404
|
||||
9. Expired session → 404
|
||||
10. Session ownership: user B cannot access user A's session
|
||||
11. No JWT on /start → 401
|
||||
12. No JWT on /message → 401
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.routes.agent_setup import (
|
||||
_SESSION_TTL_SECONDS,
|
||||
_TEMPLATE_END,
|
||||
_TEMPLATE_START,
|
||||
_extract_template,
|
||||
_sessions,
|
||||
)
|
||||
from app.models import LocalAgentConfig
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
|
||||
body: dict = {"agent_type": agent_type}
|
||||
if agent_id:
|
||||
body["agent_id"] = agent_id
|
||||
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
|
||||
return resp
|
||||
|
||||
|
||||
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
|
||||
return client.post(
|
||||
"/api/v1/agents/journey/message",
|
||||
json={"session_id": session_id, "message": message},
|
||||
headers=auth_header(tier),
|
||||
)
|
||||
|
||||
|
||||
# ── Unit: _extract_template ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_template_present():
|
||||
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
|
||||
result = _extract_template(text)
|
||||
assert result == "Extract tasks from emails."
|
||||
|
||||
|
||||
def test_extract_template_absent():
|
||||
assert _extract_template("No markers here.") is None
|
||||
|
||||
|
||||
def test_extract_template_empty_content():
|
||||
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
|
||||
assert _extract_template(text) is None
|
||||
|
||||
|
||||
# ── Start journey ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_start_journey_local(client: TestClient):
|
||||
resp = _start(client, agent_type="local")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "session_id" in body
|
||||
assert body["done"] is False
|
||||
assert body["prompt_template"] is None
|
||||
assert len(body["message"]) > 0
|
||||
# Local question should be about files/directories
|
||||
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
|
||||
|
||||
|
||||
def test_start_journey_cloud(client: TestClient):
|
||||
resp = _start(client, agent_type="cloud")
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["done"] is False
|
||||
# Cloud question should mention emails or messages
|
||||
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
|
||||
|
||||
|
||||
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
|
||||
"""When agent_id is provided, session should be created even if agent doesn't exist."""
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
|
||||
# Should succeed gracefully even if the agent_id doesn't exist
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["done"] is False
|
||||
|
||||
|
||||
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
|
||||
"""When a real local agent is provided, session is seeded with its prompt_template."""
|
||||
import asyncio
|
||||
|
||||
user_id = TEST_USER_IDS["power"]
|
||||
agent = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
name="Test Agent",
|
||||
device_id="device-1",
|
||||
directory_paths=["/home/user/emails"],
|
||||
data_types=["tasks"],
|
||||
prompt_template="Extract tasks from .eml files.",
|
||||
file_extensions=[".eml"],
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
async def _seed():
|
||||
db_session.add(agent)
|
||||
await db_session.commit()
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(_seed())
|
||||
|
||||
resp = _start(client, agent_type="local", agent_id=agent.id)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["done"] is False
|
||||
# The session should be stored
|
||||
assert body["session_id"] in _sessions
|
||||
|
||||
|
||||
def test_start_journey_requires_auth(client: TestClient):
|
||||
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ── Message ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_message_continues_conversation(client: TestClient):
|
||||
"""A mid-journey reply (no template markers) returns done=False."""
|
||||
follow_up = "That looks good. Can you tell me more about priority rules?"
|
||||
|
||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
||||
start_resp = _start(client, agent_type="local")
|
||||
assert start_resp.status_code == 200
|
||||
session_id = start_resp.json()["session_id"]
|
||||
|
||||
msg_resp = _message(client, session_id, "I have .eml and .txt files")
|
||||
assert msg_resp.status_code == 200
|
||||
body = msg_resp.json()
|
||||
assert body["done"] is False
|
||||
assert body["prompt_template"] is None
|
||||
assert body["message"] == follow_up
|
||||
assert body["session_id"] == session_id
|
||||
|
||||
|
||||
def test_message_produces_template(client: TestClient):
|
||||
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
|
||||
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
|
||||
llm_response = (
|
||||
"Great, I have all the information I need.\n"
|
||||
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
|
||||
)
|
||||
|
||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
|
||||
start_resp = _start(client, agent_type="cloud")
|
||||
assert start_resp.status_code == 200
|
||||
session_id = start_resp.json()["session_id"]
|
||||
|
||||
msg_resp = _message(client, session_id, "Only invoices from clients")
|
||||
assert msg_resp.status_code == 200
|
||||
body = msg_resp.json()
|
||||
assert body["done"] is True
|
||||
assert body["prompt_template"] == final_template
|
||||
# Session should be cleaned up
|
||||
assert session_id not in _sessions
|
||||
|
||||
|
||||
def test_message_invalid_session(client: TestClient):
|
||||
resp = _message(client, "nonexistent-session-id", "hello")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_message_wrong_owner(client: TestClient):
|
||||
"""User B cannot access user A's session."""
|
||||
start_resp = _start(client, agent_type="local", tier="power")
|
||||
session_id = start_resp.json()["session_id"]
|
||||
|
||||
# user with "pro" tier (different user_id) tries to send a message
|
||||
resp = client.post(
|
||||
"/api/v1/agents/journey/message",
|
||||
json={"session_id": session_id, "message": "hello"},
|
||||
headers=auth_header("pro"), # different user
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_message_expired_session(client: TestClient):
|
||||
"""Expired sessions return 404."""
|
||||
start_resp = _start(client, agent_type="local")
|
||||
session_id = start_resp.json()["session_id"]
|
||||
|
||||
# Manually expire the session
|
||||
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
|
||||
|
||||
resp = _message(client, session_id, "hello")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_message_requires_auth(client: TestClient):
|
||||
resp = client.post(
|
||||
"/api/v1/agents/journey/message",
|
||||
json={"session_id": "any", "message": "hello"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_message_max_turns_nudge(client: TestClient):
|
||||
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
|
||||
from app.api.routes.agent_setup import _MAX_TURNS
|
||||
|
||||
follow_up = "Tell me more about priority rules."
|
||||
|
||||
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
||||
start_resp = _start(client, agent_type="local")
|
||||
session_id = start_resp.json()["session_id"]
|
||||
|
||||
for i in range(_MAX_TURNS):
|
||||
resp = _message(client, session_id, f"Answer {i + 1}")
|
||||
assert resp.status_code == 200
|
||||
# While no template produced, session must still exist
|
||||
if resp.json()["done"]:
|
||||
break # LLM decided to wrap up early — also fine
|
||||
163
tests/test_brief_agent.py
Normal file
163
tests/test_brief_agent.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for Phase 3: brief agent WS frame + REST fallback.
|
||||
|
||||
Coverage:
|
||||
- run_home_brief streams non-empty text (mocked _run_single_agent_stream)
|
||||
- run_project_brief with bogus UUID → WS returns stream_end with error, no crash
|
||||
- _build_read_tools uses read-only subset only (no mutating tools)
|
||||
- POST /chat/brief home mode returns {response: "..."}
|
||||
- POST /chat/brief project mode with invalid UUID → 422
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_ID = TEST_USER_IDS["pro"]
|
||||
_EMPTY_CONTEXT: dict[str, Any] = {"core_memory": {}}
|
||||
|
||||
|
||||
async def _fake_token_stream(*_args, **_kwargs) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Fake _run_single_agent_stream that yields two token events."""
|
||||
yield ("token", "Hello")
|
||||
yield ("token", " world")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit: run_home_brief streams non-empty text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_home_brief_streams_text():
|
||||
with patch(
|
||||
"app.core.brief_agent._run_single_agent_stream",
|
||||
side_effect=_fake_token_stream,
|
||||
):
|
||||
from app.core.brief_agent import run_home_brief
|
||||
|
||||
chunks: list[str] = []
|
||||
async for event_type, data in run_home_brief(_USER_ID, _EMPTY_CONTEXT):
|
||||
if event_type == "token":
|
||||
chunks.append(str(data))
|
||||
|
||||
assert "".join(chunks) == "Hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit: run_project_brief streams text with valid UUID
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_project_brief_streams_text():
|
||||
project_id = str(uuid.uuid4())
|
||||
with patch(
|
||||
"app.core.brief_agent._run_single_agent_stream",
|
||||
side_effect=_fake_token_stream,
|
||||
):
|
||||
from app.core.brief_agent import run_project_brief
|
||||
|
||||
chunks: list[str] = []
|
||||
async for event_type, data in run_project_brief(_USER_ID, project_id, _EMPTY_CONTEXT):
|
||||
if event_type == "token":
|
||||
chunks.append(str(data))
|
||||
|
||||
assert "".join(chunks) == "Hello world"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit: _build_read_tools uses read-only subset (no write tools)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_build_read_tools_read_only_subset():
|
||||
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||
from app.agents.task_agent import TASK_READ_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||
from app.core.brief_agent import _build_read_tools
|
||||
|
||||
tools = _build_read_tools(_USER_ID, None)
|
||||
tool_names = {getattr(t, "name", None) or getattr(t, "__name__", str(t)) for t in tools}
|
||||
|
||||
# Read-only exports must be present.
|
||||
for read_list in (TASK_READ_TOOLS, PROJECT_READ_TOOLS, TIMELINE_READ_TOOLS, NOTE_READ_TOOLS):
|
||||
for t in read_list:
|
||||
name = getattr(t, "name", None) or getattr(t, "__name__", str(t))
|
||||
assert name in tool_names, f"Read tool {name!r} missing from _build_read_tools"
|
||||
|
||||
# No mutating tools (e.g. create_task, update_task, delete_task).
|
||||
mutating = {"create_task", "update_task", "delete_task", "create_project",
|
||||
"update_project", "delete_project", "create_note", "update_note",
|
||||
"delete_note", "memory_add", "memory_update", "memory_delete"}
|
||||
overlap = tool_names & mutating
|
||||
assert not overlap, f"Mutating tools in brief read-only subset: {overlap}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: POST /chat/brief — home mode
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_brief_home_returns_response(client):
|
||||
async def _fake_home_brief(user_id, context):
|
||||
yield ("token", "Today looks light.")
|
||||
|
||||
with (
|
||||
patch("app.api.routes.chat.run_home_brief", side_effect=_fake_home_brief),
|
||||
patch(
|
||||
"app.api.routes.chat.MemoryMiddleware.enrich_context",
|
||||
new=AsyncMock(return_value={}),
|
||||
),
|
||||
):
|
||||
res = client.post(
|
||||
"/api/v1/chat/brief",
|
||||
json={"mode": "home"},
|
||||
headers=auth_header("pro"),
|
||||
)
|
||||
|
||||
assert res.status_code == 200
|
||||
data = res.json()
|
||||
assert data["response"] == "Today looks light."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_brief_project_invalid_uuid_returns_422(client):
|
||||
res = client.post(
|
||||
"/api/v1/chat/brief",
|
||||
json={"mode": "project", "project_id": "not-a-uuid"},
|
||||
headers=auth_header("pro"),
|
||||
)
|
||||
assert res.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rest_brief_project_missing_uuid_returns_422(client):
|
||||
res = client.post(
|
||||
"/api/v1/chat/brief",
|
||||
json={"mode": "project"},
|
||||
headers=auth_header("pro"),
|
||||
)
|
||||
assert res.status_code == 422
|
||||
@@ -1,184 +0,0 @@
|
||||
"""Unit tests for Step 1 file classification (_classify_file).
|
||||
|
||||
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
|
||||
Run with: pytest tests/test_classify_file.py -v
|
||||
|
||||
To run a quick manual check against a real file without the full UI:
|
||||
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_runner import _classify_file
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────
|
||||
|
||||
PROJECTS_SAMPLE = [
|
||||
{
|
||||
"id": "aaaa-0001-0000-0000-000000000001",
|
||||
"name": "ARPA Sicilia POC",
|
||||
"status": "active",
|
||||
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
|
||||
},
|
||||
{
|
||||
"id": "bbbb-0002-0000-0000-000000000002",
|
||||
"name": "SNAM AI Meeting Prep",
|
||||
"status": "active",
|
||||
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
|
||||
},
|
||||
{
|
||||
"id": "cccc-0003-0000-0000-000000000003",
|
||||
"name": "SFERA+ Wave 2",
|
||||
"status": "active",
|
||||
"aiSummary": "Second wave of the SFERA+ whitelist project.",
|
||||
},
|
||||
]
|
||||
|
||||
ARPA_EMAIL = """\
|
||||
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
|
||||
isImportance: normal
|
||||
hasAttachment: True
|
||||
---
|
||||
## Body
|
||||
Buongiorno,
|
||||
|
||||
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
|
||||
dei deliverable concordati:
|
||||
- Preparare demo entro il 30 marzo
|
||||
- Condividere documentazione tecnica con il team ARPA
|
||||
- Fissare call di follow-up la prossima settimana
|
||||
|
||||
Cordiali saluti
|
||||
Roberto Marchetti
|
||||
"""
|
||||
|
||||
SNAM_EMAIL = """\
|
||||
to: roberto.musso@hpe.com
|
||||
isImportance: high
|
||||
hasAttachment: False
|
||||
---
|
||||
## Body
|
||||
Ciao,
|
||||
ti invio l'agenda per la riunione SNAM di domani.
|
||||
Per favore conferma la tua presenza.
|
||||
"""
|
||||
|
||||
UNRELATED_EMAIL = """\
|
||||
to: roberto.musso@hpe.com
|
||||
isImportance: normal
|
||||
---
|
||||
## Body
|
||||
Benvenuto nel programma HPE Employee Learning Series.
|
||||
Completa la formazione richiesta entro la fine del trimestre.
|
||||
"""
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_arpa_matches_existing():
|
||||
project_id, domains, new_name = await _classify_file(
|
||||
file_path="arpa_email.txt",
|
||||
file_content=ARPA_EMAIL,
|
||||
projects=PROJECTS_SAMPLE,
|
||||
config_data_types=["tasks", "notes", "timelines"],
|
||||
)
|
||||
assert project_id == "aaaa-0001-0000-0000-000000000001", (
|
||||
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
|
||||
)
|
||||
assert new_name is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_snam_matches_existing():
|
||||
project_id, domains, new_name = await _classify_file(
|
||||
file_path="snam_email.txt",
|
||||
file_content=SNAM_EMAIL,
|
||||
projects=PROJECTS_SAMPLE,
|
||||
config_data_types=["tasks", "notes"],
|
||||
)
|
||||
assert project_id == "bbbb-0002-0000-0000-000000000002", (
|
||||
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_unrelated_returns_new():
|
||||
project_id, domains, new_name = await _classify_file(
|
||||
file_path="learning_email.txt",
|
||||
file_content=UNRELATED_EMAIL,
|
||||
projects=PROJECTS_SAMPLE,
|
||||
config_data_types=["tasks", "notes"],
|
||||
)
|
||||
assert project_id == "new"
|
||||
assert new_name is not None # LLM should suggest a name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_empty_file_returns_new():
|
||||
project_id, domains, new_name = await _classify_file(
|
||||
file_path="empty.txt",
|
||||
file_content=" ",
|
||||
projects=PROJECTS_SAMPLE,
|
||||
config_data_types=["tasks"],
|
||||
)
|
||||
assert project_id == "new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_no_projects_returns_new():
|
||||
project_id, domains, new_name = await _classify_file(
|
||||
file_path="arpa_email.txt",
|
||||
file_content=ARPA_EMAIL,
|
||||
projects=[],
|
||||
config_data_types=["tasks", "notes"],
|
||||
)
|
||||
assert project_id == "new"
|
||||
assert new_name is not None
|
||||
|
||||
|
||||
# ── CLI quick-test runner ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _cli_test(file_path: str, project_names: list[str]) -> None:
|
||||
"""Run Step 1 classification against a real file from the CLI."""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
|
||||
projects = [
|
||||
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
|
||||
for i, name in enumerate(project_names)
|
||||
]
|
||||
|
||||
print(f"\nClassifying: {file_path}")
|
||||
print(f"Projects in context: {[p['name'] for p in projects]}\n")
|
||||
|
||||
project_id, domains, new_name = await _classify_file(
|
||||
file_path=file_path,
|
||||
file_content=content,
|
||||
projects=projects,
|
||||
config_data_types=["tasks", "notes", "timelines"],
|
||||
)
|
||||
|
||||
result = {
|
||||
"project_id": project_id,
|
||||
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
|
||||
"new_project_name": new_name,
|
||||
"domains": domains,
|
||||
}
|
||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
|
||||
sys.exit(1)
|
||||
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))
|
||||
52
tests/test_contextual_scope.py
Normal file
52
tests/test_contextual_scope.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import pytest
|
||||
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||
|
||||
|
||||
def test_render_project_scope():
|
||||
scope = ContextualScope(
|
||||
page="project",
|
||||
entity_type="project",
|
||||
entity_id="p1",
|
||||
entity_name="Acme Q3 launch",
|
||||
counts={"tasks": 12, "notes": 4, "milestones": 3},
|
||||
)
|
||||
block = render_scope_block(scope)
|
||||
assert "Acme Q3 launch" in block
|
||||
assert "12 tasks" in block
|
||||
assert "4 notes" in block
|
||||
assert "3 milestones" in block
|
||||
assert "p1" not in block
|
||||
|
||||
|
||||
def test_render_list_scope_no_entity():
|
||||
scope = ContextualScope(page="tasks", entity_type=None)
|
||||
block = render_scope_block(scope)
|
||||
assert "tasks" in block.lower()
|
||||
assert "None" not in block
|
||||
|
||||
|
||||
def test_render_note_scope_includes_char_count():
|
||||
scope = ContextualScope(
|
||||
page="note",
|
||||
entity_type="note",
|
||||
entity_id="n1",
|
||||
entity_name="Meeting 14 May",
|
||||
project_id="p1",
|
||||
char_count=4280,
|
||||
)
|
||||
block = render_scope_block(scope)
|
||||
assert "Meeting 14 May" in block
|
||||
assert "4280" in block or "4,280" in block
|
||||
|
||||
|
||||
def test_parses_camelcase_payload_from_renderer():
|
||||
payload = {
|
||||
"page": "project",
|
||||
"entityType": "project",
|
||||
"entityId": "p1",
|
||||
"entityName": "Acme",
|
||||
"counts": {"tasks": 5, "notes": 1, "milestones": 2},
|
||||
}
|
||||
scope = ContextualScope.model_validate(payload)
|
||||
assert scope.entity_id == "p1"
|
||||
assert scope.entity_name == "Acme"
|
||||
44
tests/test_contextual_ws.py
Normal file
44
tests/test_contextual_ws.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Tests for contextual WS frame handlers.
|
||||
|
||||
These tests only exercise the new handler functions in device_ws.py and do
|
||||
not depend on litellm or the full deep_agent import chain. They monkeypatch
|
||||
run_contextual_stream so no LLM call is made.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch):
|
||||
"""_handle_contextual_scope_update must:
|
||||
- call append_system_message on the session buffer
|
||||
- send a contextual_scope_ack back on the socket
|
||||
- make no LLM call
|
||||
"""
|
||||
from app.api.routes import device_ws
|
||||
|
||||
ws = AsyncMock()
|
||||
buffer = MagicMock()
|
||||
buffer.append_system_message = MagicMock()
|
||||
|
||||
payload = {
|
||||
"type": "contextual_scope_update",
|
||||
"session_id": "s1",
|
||||
"scope": {
|
||||
"page": "project",
|
||||
"entityType": "project",
|
||||
"entityId": "p1",
|
||||
"entityName": "Acme",
|
||||
"counts": {"tasks": 1, "notes": 0, "milestones": 0},
|
||||
},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer)
|
||||
await device_ws._handle_contextual_scope_update(ws, "user1", payload)
|
||||
|
||||
ws.send_text.assert_awaited_once()
|
||||
import json
|
||||
sent = json.loads(ws.send_text.await_args.args[0])
|
||||
assert sent["type"] == "contextual_scope_ack"
|
||||
assert sent["session_id"] == "s1"
|
||||
buffer.append_system_message.assert_called_once()
|
||||
@@ -10,10 +10,10 @@ import pytest
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
from app.core.deep_agent import (
|
||||
_infer_floating_domain,
|
||||
_build_system_prompt,
|
||||
_datetime_context_injection,
|
||||
_normalize_tagged_list_lines,
|
||||
run_floating,
|
||||
run_floating_stream,
|
||||
_request_context_block,
|
||||
run_home,
|
||||
)
|
||||
|
||||
@@ -63,7 +63,7 @@ class _FakeLLM:
|
||||
async def test_run_home_uses_mocked_tool_result():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||
):
|
||||
out = await run_home("user-1", "list my tasks", {})
|
||||
@@ -72,53 +72,6 @@ async def test_run_home_uses_mocked_tool_result():
|
||||
assert "Mock Task" in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||
):
|
||||
events = []
|
||||
async for event in run_floating_stream(
|
||||
"user-1",
|
||||
"show me timeline updates",
|
||||
{"scope": {"type": "timeline", "id": "tl-1"}},
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert events[0] == (
|
||||
"floating_domain",
|
||||
{"type": "timeline", "id": "tl-1", "section": None},
|
||||
)
|
||||
assert ("token", "stream-") in events
|
||||
assert ("token", "ok") in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
||||
class _ClassifierOnlyLLM:
|
||||
async def ainvoke(self, _messages):
|
||||
return AIMessage(
|
||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||
)
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
||||
domain = await _infer_floating_domain(
|
||||
"Quali sono i miei task per il progetto X",
|
||||
{
|
||||
"scope": {"type": "timeline"},
|
||||
"resolved_project_id": "213213-312321-312312-421321",
|
||||
},
|
||||
)
|
||||
|
||||
assert domain == {
|
||||
"type": "project",
|
||||
"id": "213213-312321-312312-421321",
|
||||
"section": "task",
|
||||
}
|
||||
|
||||
|
||||
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
||||
raw = (
|
||||
"Certo!\n\n"
|
||||
@@ -155,134 +108,211 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
|
||||
assert "<timeline>[tl-future]</timeline>" not in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
||||
fake_llm = _FakeLLM()
|
||||
# ── _datetime_context_injection ────────────────────────────────────────────────
|
||||
|
||||
async def _fake_run_single_agent(**_kwargs):
|
||||
return (
|
||||
"Hai 1 task:\\n"
|
||||
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
)
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||
):
|
||||
text, _domain = await run_floating(
|
||||
"user-1",
|
||||
"quali task ho?",
|
||||
{"scope": {"type": "task"}},
|
||||
)
|
||||
|
||||
assert "<task>" not in text
|
||||
assert "</task>" not in text
|
||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
||||
def _fp(tz: str, now_iso: str) -> dict:
|
||||
return {"timezone": tz, "now_iso": now_iso, "date_format": "dd/MM/yyyy", "time_format": "24h"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
async def _fake_stream(**_kwargs):
|
||||
yield "token", "Hai 1 task:\\n"
|
||||
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||
):
|
||||
events = []
|
||||
async for event in run_floating_stream(
|
||||
"user-1",
|
||||
"quali task ho?",
|
||||
{"scope": {"type": "task"}},
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
||||
combined = "".join(token_events)
|
||||
assert "<task>" not in combined
|
||||
assert "</task>" not in combined
|
||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
||||
def _parse_ms(block: str, key: str) -> tuple[int, int]:
|
||||
"""Extract [start, end] from a 'key [start, end]' line in the DATE CONTEXT block."""
|
||||
import re
|
||||
m = re.search(rf"^{key}\s+\[(\d+),\s*(\d+)\]", block, re.MULTILINE)
|
||||
assert m, f"Key '{key}' not found in block:\n{block}"
|
||||
return int(m.group(1)), int(m.group(2))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
||||
class _NoChunkLLM:
|
||||
def __init__(self) -> None:
|
||||
self.calls = 0
|
||||
def test_datetime_context_injection_europe_rome_late_evening():
|
||||
"""22:16 CEST on 2026-04-26 — 'tomorrow' must be 2026-04-27 00:00→23:59:59.999 CEST."""
|
||||
from zoneinfo import ZoneInfo
|
||||
from datetime import datetime, timezone
|
||||
|
||||
def bind_tools(self, _tools):
|
||||
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z")})
|
||||
assert "DATE CONTEXT" in block
|
||||
assert "Europe/Rome" in block
|
||||
|
||||
tz = ZoneInfo("Europe/Rome")
|
||||
today_start = int(datetime(2026, 4, 26, tzinfo=tz).timestamp() * 1000)
|
||||
today_end = int(datetime(2026, 4, 27, tzinfo=tz).timestamp() * 1000) - 1
|
||||
tomorrow_start = today_end + 1
|
||||
tomorrow_end = int(datetime(2026, 4, 28, tzinfo=tz).timestamp() * 1000) - 1
|
||||
|
||||
t_s, t_e = _parse_ms(block, "today")
|
||||
assert t_s == today_start
|
||||
assert t_e == today_end
|
||||
|
||||
tm_s, tm_e = _parse_ms(block, "tomorrow")
|
||||
assert tm_s == tomorrow_start
|
||||
assert tm_e == tomorrow_end
|
||||
|
||||
# Sanity: window is exactly 86 400 000 ms (1 day, CEST has no DST jump on this date)
|
||||
assert today_end - today_start + 1 == 86_400_000
|
||||
assert tomorrow_end - tomorrow_start + 1 == 86_400_000
|
||||
|
||||
|
||||
def test_datetime_context_injection_utc():
|
||||
"""UTC timezone: boundaries are clean UTC midnights."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-01-15T10:00:00Z")})
|
||||
t_s, t_e = _parse_ms(block, "today")
|
||||
expected_start = int(datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() * 1000)
|
||||
assert t_s == expected_start
|
||||
assert t_e == expected_start + 86_400_000 - 1
|
||||
|
||||
|
||||
def test_datetime_context_injection_dst_spring_forward():
|
||||
"""Europe/Rome DST spring-forward 2026-03-29: that day is 23h, not 24h."""
|
||||
from zoneinfo import ZoneInfo
|
||||
from datetime import datetime
|
||||
|
||||
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-03-29T08:00:00Z")})
|
||||
tz = ZoneInfo("Europe/Rome")
|
||||
day_start = int(datetime(2026, 3, 29, tzinfo=tz).timestamp() * 1000)
|
||||
day_end = int(datetime(2026, 3, 30, tzinfo=tz).timestamp() * 1000) - 1
|
||||
|
||||
t_s, t_e = _parse_ms(block, "today")
|
||||
assert t_s == day_start
|
||||
assert t_e == day_end
|
||||
assert t_e - t_s + 1 == 23 * 3_600_000 # 23-hour day
|
||||
|
||||
|
||||
def test_datetime_context_injection_dst_fall_back():
|
||||
"""Europe/Rome DST fall-back 2026-10-25: that day is 25h."""
|
||||
from zoneinfo import ZoneInfo
|
||||
from datetime import datetime
|
||||
|
||||
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-10-25T08:00:00Z")})
|
||||
tz = ZoneInfo("Europe/Rome")
|
||||
day_start = int(datetime(2026, 10, 25, tzinfo=tz).timestamp() * 1000)
|
||||
day_end = int(datetime(2026, 10, 26, tzinfo=tz).timestamp() * 1000) - 1
|
||||
|
||||
t_s, t_e = _parse_ms(block, "today")
|
||||
assert t_s == day_start
|
||||
assert t_e == day_end
|
||||
assert t_e - t_s + 1 == 25 * 3_600_000 # 25-hour day
|
||||
|
||||
|
||||
def test_datetime_context_injection_year_boundary():
|
||||
"""Dec 31 → Jan 1: last_year, this_year, next_month cross year boundary correctly."""
|
||||
from zoneinfo import ZoneInfo
|
||||
from datetime import datetime
|
||||
|
||||
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-12-31T23:00:00Z")})
|
||||
tz = ZoneInfo("UTC")
|
||||
|
||||
yr_s, yr_e = _parse_ms(block, "this_year")
|
||||
assert yr_s == int(datetime(2026, 1, 1, tzinfo=tz).timestamp() * 1000)
|
||||
assert yr_e == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000) - 1
|
||||
|
||||
ly_s, ly_e = _parse_ms(block, "last_year")
|
||||
assert ly_s == int(datetime(2025, 1, 1, tzinfo=tz).timestamp() * 1000)
|
||||
assert ly_e == yr_s - 1
|
||||
|
||||
nm_s, _ = _parse_ms(block, "next_month")
|
||||
assert nm_s == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000)
|
||||
|
||||
|
||||
def test_datetime_context_injection_missing_format_prefs():
|
||||
assert _datetime_context_injection({}) == ""
|
||||
assert _datetime_context_injection({"format_prefs": None}) == ""
|
||||
assert _datetime_context_injection({"format_prefs": "bad"}) == ""
|
||||
|
||||
|
||||
# ── _request_context_block ─────────────────────────────────────────────────────
|
||||
|
||||
def test_request_context_block_scope_and_project():
|
||||
ctx = {"scope": {"type": "task", "id": "t-1"}, "resolved_project_id": "proj-uuid"}
|
||||
block = _request_context_block(ctx)
|
||||
assert "scope" in block
|
||||
assert "resolved_project_id: proj-uuid" in block
|
||||
|
||||
|
||||
def test_request_context_block_empty():
|
||||
assert _request_context_block({}) == ""
|
||||
assert _request_context_block({"scope": None}) == ""
|
||||
|
||||
|
||||
# ── _build_system_prompt ───────────────────────────────────────────────────────
|
||||
|
||||
def test_build_system_prompt_substitutes_all_slots(monkeypatch):
|
||||
"""All five slots must appear in the compiled output; no raw placeholder remains."""
|
||||
# Patch get_prompt_or_fallback to return None prompt_obj so we use fallback .format() path
|
||||
import app.core.deep_agent as da
|
||||
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
|
||||
|
||||
ctx = {
|
||||
"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z"),
|
||||
"core_memory": {"language": "it"},
|
||||
"relational_memory": ["Alice — client"],
|
||||
"proactive_hints": ["User prefers morning meetings"],
|
||||
"scope": {"type": "task"},
|
||||
"resolved_project_id": "proj-1",
|
||||
}
|
||||
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
|
||||
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, ctx)
|
||||
|
||||
# No unresolved placeholders
|
||||
assert "{date_context}" not in text
|
||||
assert "{language_instruction}" not in text
|
||||
assert "{relational_memory}" not in text
|
||||
assert "{proactive_hints}" not in text
|
||||
assert "{request_context}" not in text
|
||||
|
||||
# Content was injected
|
||||
assert "DATE CONTEXT" in text
|
||||
assert "Italian" in text
|
||||
assert "Alice" in text
|
||||
assert "morning meetings" in text
|
||||
assert "proj-1" in text
|
||||
|
||||
|
||||
def test_build_system_prompt_empty_format_prefs(monkeypatch):
|
||||
"""Missing format_prefs must not raise — date_context slot renders empty string."""
|
||||
import app.core.deep_agent as da
|
||||
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
|
||||
|
||||
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
|
||||
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, {})
|
||||
# Prompt renders without error; date section is empty but structure holds
|
||||
assert "# Date filtering" in text
|
||||
assert "{date_context}" not in text
|
||||
|
||||
|
||||
def test_human_message_is_bare_message(monkeypatch):
|
||||
"""After the refactor HumanMessage content must equal the raw user message exactly."""
|
||||
import app.core.deep_agent as da
|
||||
from langchain_core.messages import HumanMessage as LCHumanMessage
|
||||
|
||||
captured: list[list] = []
|
||||
|
||||
class _CaptureLLM:
|
||||
def bind_tools(self, _):
|
||||
return self
|
||||
|
||||
async def ainvoke(self, _messages):
|
||||
self.calls += 1
|
||||
if self.calls == 1:
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call-1",
|
||||
"name": "list_tasks",
|
||||
"args": {},
|
||||
}
|
||||
],
|
||||
)
|
||||
return AIMessage(content="No notes found.")
|
||||
async def ainvoke(self, messages):
|
||||
captured.append(list(messages))
|
||||
return AIMessage(content="risposta")
|
||||
|
||||
async def astream(self, _messages):
|
||||
if False:
|
||||
yield None
|
||||
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda n, f: (f, None))
|
||||
monkeypatch.setattr(da, "get_agent_llm", lambda _: _CaptureLLM())
|
||||
monkeypatch.setattr(da, "_all_tools_for_user", lambda *_: [])
|
||||
monkeypatch.setattr(da, "get_langfuse", lambda: None)
|
||||
monkeypatch.setattr(da, "set_tool_result_collector", lambda _: None)
|
||||
monkeypatch.setattr(da, "clear_tool_result_collector", lambda: None)
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||
):
|
||||
events = []
|
||||
async for event in run_floating_stream(
|
||||
"user-1",
|
||||
"quali sono le note?",
|
||||
{"scope": {"type": "note"}},
|
||||
):
|
||||
events.append(event)
|
||||
import asyncio
|
||||
|
||||
assert events[0][0] == "floating_domain"
|
||||
assert ("token", "No notes found.") in events
|
||||
async def _run():
|
||||
chunks = []
|
||||
ctx = {"format_prefs": _fp("UTC", "2026-04-27T10:00:00Z")}
|
||||
async for ev in da.run_home_stream("u1", "Cosa devo fare domani?", ctx):
|
||||
chunks.append(ev)
|
||||
|
||||
asyncio.get_event_loop().run_until_complete(_run())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
async def _fake_run_single_agent(**_kwargs):
|
||||
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||
):
|
||||
text, _domain = await run_floating(
|
||||
"user-1",
|
||||
"quali task ho?",
|
||||
{"scope": {"type": "task"}},
|
||||
)
|
||||
|
||||
assert text == "No results found."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
async def _fake_stream(**_kwargs):
|
||||
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||
):
|
||||
events = []
|
||||
async for event in run_floating_stream(
|
||||
"user-1",
|
||||
"quali task ho?",
|
||||
{"scope": {"type": "task"}},
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert ("token", "No results found.") in events
|
||||
assert captured, "LLM was never called"
|
||||
messages = captured[0]
|
||||
human = next(m for m in messages if isinstance(m, LCHumanMessage))
|
||||
assert human.content == "Cosa devo fare domani?"
|
||||
assert "Context:" not in human.content
|
||||
|
||||
@@ -156,40 +156,6 @@ async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
|
||||
assert fut.cancelled()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_agent_data_queue(manager, mock_ws):
|
||||
manager.register("user1", "dev-A", mock_ws)
|
||||
q = manager.get_agent_data_queue("user1", "run-xyz")
|
||||
# Put a frame and get it back.
|
||||
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
|
||||
await q.put(frame)
|
||||
assert await q.get() == frame
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
|
||||
manager.register("user1", "dev-A", mock_ws)
|
||||
q1 = manager.get_agent_data_queue("user1", "run-1")
|
||||
q2 = manager.get_agent_data_queue("user1", "run-1")
|
||||
assert q1 is q2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_agent_data_queue_raises_when_offline(manager):
|
||||
with pytest.raises(RuntimeError, match="not connected"):
|
||||
manager.get_agent_data_queue("ghost", "run-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
|
||||
manager.register("user1", "dev-A", mock_ws)
|
||||
manager.get_agent_data_queue("user1", "run-1")
|
||||
manager.cleanup_agent_data_queue("user1", "run-1")
|
||||
# After cleanup a new queue is created (not the same object).
|
||||
q_new = manager.get_agent_data_queue("user1", "run-1")
|
||||
assert q_new is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests — /api/v1/ws/device endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -235,7 +201,6 @@ def test_ws_device_invalid_first_frame_closes(client):
|
||||
def test_ws_device_tool_result_dispatched(client):
|
||||
"""tool_result frame is routed to the DeviceConnectionManager."""
|
||||
token = make_jwt(tier="free")
|
||||
user_id = TEST_USER_IDS["free"]
|
||||
|
||||
from app.core.device_manager import device_manager as dm
|
||||
|
||||
@@ -266,43 +231,6 @@ def test_ws_device_tool_result_dispatched(client):
|
||||
assert any(c["call_id"] == "call-123" for c in captured)
|
||||
|
||||
|
||||
def test_ws_device_agent_data_enqueued(client):
|
||||
"""agent_data frame is placed in the per-run queue by the message loop."""
|
||||
from app.core.device_manager import device_manager as dm
|
||||
|
||||
token = make_jwt(tier="free")
|
||||
user_id = TEST_USER_IDS["free"]
|
||||
|
||||
# Capture the queue object the message loop accesses.
|
||||
captured_queue: list[asyncio.Queue] = []
|
||||
original_get_queue = dm.get_agent_data_queue
|
||||
|
||||
def _spy_get_queue(uid, run_id):
|
||||
q = original_get_queue(uid, run_id)
|
||||
if not captured_queue:
|
||||
captured_queue.append(q)
|
||||
return q
|
||||
|
||||
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
|
||||
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||
ws.send_text(_device_hello("dev-001"))
|
||||
ws.send_text(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "agent_data",
|
||||
"run_id": "run-XYZ",
|
||||
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
|
||||
}
|
||||
)
|
||||
)
|
||||
ws.close()
|
||||
|
||||
# The queue should have received exactly one frame.
|
||||
assert captured_queue, "queue was never accessed"
|
||||
assert not captured_queue[0].empty()
|
||||
|
||||
|
||||
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
|
||||
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
|
||||
from app.api.routes import device_ws as _dws
|
||||
|
||||
139
tests/test_folder_agent_tool.py
Normal file
139
tests/test_folder_agent_tool.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.agents.folder_agent import (
|
||||
read_project_folder_file,
|
||||
search_project_folder_file,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_happy_path():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
|
||||
assert "file body" in out
|
||||
assert "kind=text" in out
|
||||
|
||||
|
||||
async def test_traversal_rejected():
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
|
||||
assert out == "Access denied"
|
||||
|
||||
|
||||
async def test_absolute_path_rejected():
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
|
||||
assert out == "Access denied"
|
||||
|
||||
|
||||
async def test_missing_file():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
|
||||
assert "not found" in out.lower()
|
||||
|
||||
|
||||
async def test_pagination_signals_more_available():
|
||||
# Electron returned the first slice, totalSize larger than slice length.
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "big.txt",
|
||||
"offset": 0,
|
||||
"length": 11,
|
||||
})
|
||||
assert "first chunk" in out
|
||||
assert "More content available" in out
|
||||
assert "offset=11" in out
|
||||
|
||||
|
||||
async def test_pdf_extracted_then_sliced(monkeypatch):
|
||||
from app.agents import folder_agent
|
||||
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "doc.pdf",
|
||||
"offset": 0,
|
||||
"length": 8,
|
||||
})
|
||||
assert "kind=pdf" in out
|
||||
assert "ABC ABC " in out
|
||||
assert "More content available" in out
|
||||
|
||||
|
||||
async def test_image_returns_placeholder():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
|
||||
):
|
||||
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
|
||||
assert "image" in out.lower()
|
||||
|
||||
|
||||
async def test_search_finds_match_with_context():
|
||||
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
|
||||
):
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "log.txt",
|
||||
"query": "needle",
|
||||
"context_lines": 1,
|
||||
})
|
||||
assert "needle" in out
|
||||
assert "matches=1" in out
|
||||
# Context lines included
|
||||
assert "beta" in out
|
||||
assert "gamma" in out
|
||||
|
||||
|
||||
async def test_search_no_match():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
|
||||
):
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "x.txt",
|
||||
"query": "zzz",
|
||||
})
|
||||
assert "No matches" in out
|
||||
|
||||
|
||||
async def test_search_rejects_traversal():
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "../etc/passwd",
|
||||
"query": "root",
|
||||
})
|
||||
assert out == "Access denied"
|
||||
|
||||
|
||||
async def test_search_image_rejected():
|
||||
with patch(
|
||||
"app.agents.folder_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
|
||||
):
|
||||
out = await search_project_folder_file.ainvoke({
|
||||
"project_id": "p1",
|
||||
"relative_path": "logo.png",
|
||||
"query": "anything",
|
||||
})
|
||||
assert "Cannot search" in out
|
||||
83
tests/test_folder_indexer.py
Normal file
83
tests/test_folder_indexer.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Folder indexer LLM helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_summarize_text_returns_summary_and_tokens():
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.content = "Kickoff notes covering scope and deadlines."
|
||||
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
|
||||
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
|
||||
assert isinstance(result, IndexResult)
|
||||
assert result.summary == "Kickoff notes covering scope and deadlines."
|
||||
assert result.tokens_used == 338
|
||||
|
||||
|
||||
async def test_summarize_text_truncates_summary_at_500_chars():
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.content = "x" * 1000
|
||||
mock_resp.usage_metadata = {"total_tokens": 100}
|
||||
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
|
||||
result = await summarize_text(content="x", ext=".md", name="x.md")
|
||||
assert len(result.summary) <= 500
|
||||
|
||||
|
||||
async def test_summarize_image_uses_vision_content_blocks():
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.content = "Final logo on white background."
|
||||
mock_resp.usage_metadata = {"total_tokens": 500}
|
||||
captured = {}
|
||||
|
||||
async def fake_llm_vision(messages):
|
||||
captured["messages"] = messages
|
||||
return mock_resp
|
||||
|
||||
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
|
||||
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
|
||||
|
||||
assert "Final logo" in result.summary
|
||||
assert result.tokens_used == 500
|
||||
# last message contains an image content block
|
||||
last = captured["messages"][-1]
|
||||
assert any(
|
||||
isinstance(p, dict) and p.get("type") == "image_url"
|
||||
for p in (last.content if isinstance(last.content, list) else [])
|
||||
)
|
||||
|
||||
|
||||
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
|
||||
# pypdf.PdfReader returns text from pages
|
||||
from app.core import folder_indexer
|
||||
class FakePage:
|
||||
def extract_text(self): return "PDF page content with project info."
|
||||
class FakeReader:
|
||||
pages = [FakePage(), FakePage()]
|
||||
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
|
||||
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
|
||||
async def fake_llm(messages): return mock_resp
|
||||
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
|
||||
assert "Project info" in result.summary
|
||||
assert result.tokens_used == 50
|
||||
|
||||
|
||||
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
|
||||
from app.core import folder_indexer
|
||||
class FakePara:
|
||||
def __init__(self, t): self.text = t
|
||||
class FakeDoc:
|
||||
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
|
||||
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
|
||||
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
|
||||
async def fake_llm(messages): return mock_resp
|
||||
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
|
||||
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
|
||||
assert result.summary == "Heading and body."
|
||||
94
tests/test_folder_quota.py
Normal file
94
tests/test_folder_quota.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Folder quota helpers."""
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.billing.quota import (
|
||||
check_folder_quota,
|
||||
add_token_usage,
|
||||
QuotaExceeded,
|
||||
)
|
||||
from app.models import MonthlyTokenUsage
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
|
||||
with pytest.raises(QuotaExceeded) as exc:
|
||||
await check_folder_quota(
|
||||
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
|
||||
)
|
||||
assert exc.value.reason == "max_files"
|
||||
|
||||
|
||||
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
|
||||
# No raise
|
||||
await check_folder_quota(
|
||||
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
|
||||
)
|
||||
|
||||
|
||||
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
db.add(MonthlyTokenUsage(
|
||||
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
|
||||
))
|
||||
await db.commit()
|
||||
with pytest.raises(QuotaExceeded) as exc:
|
||||
await check_folder_quota(
|
||||
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
|
||||
)
|
||||
assert exc.value.reason == "monthly_tokens"
|
||||
|
||||
|
||||
async def test_check_folder_quota_power_unlimited(db, test_user_power):
|
||||
await check_folder_quota(
|
||||
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
|
||||
)
|
||||
|
||||
|
||||
async def test_add_token_usage_atomic_increment(db, test_user_free):
|
||||
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
|
||||
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
row = (await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == test_user_free.id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == "folder_index",
|
||||
)
|
||||
)).scalar_one()
|
||||
assert row.tokens_used == 4000
|
||||
|
||||
|
||||
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
|
||||
result = await add_token_usage(
|
||||
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
|
||||
)
|
||||
assert result.exhausted is True
|
||||
assert result.tokens_used == 150_000
|
||||
|
||||
|
||||
def test_quota_check_endpoint_rejects(client, auth_headers_free):
|
||||
res = client.post(
|
||||
"/api/v1/billing/quota/check",
|
||||
json={"feature": "folder_index", "estimated_files": 500},
|
||||
headers=auth_headers_free,
|
||||
)
|
||||
assert res.status_code == 402
|
||||
body = res.json()
|
||||
assert body["detail"]["reason"] == "max_files"
|
||||
|
||||
|
||||
def test_quota_check_endpoint_passes(client, auth_headers_free):
|
||||
res = client.post(
|
||||
"/api/v1/billing/quota/check",
|
||||
json={"feature": "folder_index", "estimated_files": 50},
|
||||
headers=auth_headers_free,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"ok": True}
|
||||
@@ -328,7 +328,7 @@ def _make_gmail_message(
|
||||
class TestGmailClientFetchMessages:
|
||||
"""GmailClient.fetch_messages tests with mocked Google API."""
|
||||
|
||||
def _make_client(self) -> "GmailClient":
|
||||
def _make_client(self):
|
||||
from app.integrations.gmail import GmailClient
|
||||
return GmailClient(_TOKEN_DICT)
|
||||
|
||||
@@ -509,7 +509,7 @@ def _make_graph_teams_message(
|
||||
class TestMSGraphClientFetchEmails:
|
||||
"""MSGraphClient.fetch_emails tests with mocked httpx."""
|
||||
|
||||
def _make_client(self) -> "MSGraphClient":
|
||||
def _make_client(self):
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
return MSGraphClient(_MS_TOKEN_DICT)
|
||||
|
||||
@@ -608,7 +608,7 @@ class TestMSGraphClientFetchEmails:
|
||||
class TestMSGraphClientFetchMessages:
|
||||
"""MSGraphClient.fetch_messages (Teams) tests."""
|
||||
|
||||
def _make_client(self) -> "MSGraphClient":
|
||||
def _make_client(self):
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
return MSGraphClient(_MS_TOKEN_DICT)
|
||||
|
||||
|
||||
69
tests/test_manifest_injection.py
Normal file
69
tests/test_manifest_injection.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
def test_format_folder_manifest_basic():
|
||||
manifest = {
|
||||
"folderPath": "D:\\Acme",
|
||||
"lastScannedAt": "2h ago",
|
||||
"files": [
|
||||
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
|
||||
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
|
||||
],
|
||||
}
|
||||
out = format_folder_manifest(manifest)
|
||||
assert "<linked_folder>" in out
|
||||
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
|
||||
assert "[text]" in out
|
||||
assert "[image]" in out
|
||||
|
||||
|
||||
def test_format_folder_manifest_truncates_past_budget():
|
||||
files = [
|
||||
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
|
||||
for i in range(2000)
|
||||
]
|
||||
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
|
||||
assert "more files omitted" in out
|
||||
# Rough token check
|
||||
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
|
||||
|
||||
|
||||
def test_format_folder_manifest_null_returns_empty():
|
||||
assert format_folder_manifest(None) == ""
|
||||
assert format_folder_manifest({"files": []}) == ""
|
||||
|
||||
|
||||
async def test_brief_multi_project_manifest_top_5_per_project():
|
||||
fake_response = [
|
||||
{
|
||||
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
|
||||
"lastScannedAt": "now",
|
||||
"files": [
|
||||
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
|
||||
for i in range(10)
|
||||
],
|
||||
},
|
||||
{
|
||||
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
|
||||
"lastScannedAt": "now",
|
||||
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
|
||||
},
|
||||
]
|
||||
with patch(
|
||||
"app.core.deep_agent.execute_on_client",
|
||||
new=AsyncMock(return_value={"projects": fake_response}),
|
||||
):
|
||||
from app.core.deep_agent import build_brief_multi_project_manifest
|
||||
out = await build_brief_multi_project_manifest()
|
||||
# Project 1 has 10 files, only top 5 by mtimeMs should appear
|
||||
assert out.count("[p1]") <= 5
|
||||
# Project 2 has 1 file, must appear
|
||||
assert "[p2]" in out or "Beta" in out
|
||||
405
tests/test_memory_audit.py
Normal file
405
tests/test_memory_audit.py
Normal file
@@ -0,0 +1,405 @@
|
||||
"""Tests for Phase 7 — weekly audit_memory job.
|
||||
|
||||
Coverage:
|
||||
1. audit_memory never raises even if inner work fails.
|
||||
2. _scan_associative_contradictions skips when < 2 decryptable facts.
|
||||
3. _scan_associative_contradictions calls LLM and deletes flagged rows.
|
||||
4. _scan_associative_contradictions is a no-op when LLM fails.
|
||||
5. _scan_associative_contradictions is a no-op when LLM returns non-list.
|
||||
6. _canonicalize_relation_labels skips when no relation rows.
|
||||
7. _canonicalize_relation_labels rewrites variant labels to canonical form.
|
||||
8. _canonicalize_relation_labels is a no-op when LLM fails.
|
||||
9. _canonicalize_relation_labels is a no-op when remap is empty.
|
||||
10. Both helpers work correctly when Langfuse is unavailable (lf=None).
|
||||
11. get_prompt_or_fallback called with correct Langfuse prompt names.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from contextlib import contextmanager, ExitStack
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_maintenance import (
|
||||
_canonicalize_relation_labels,
|
||||
_scan_associative_contradictions,
|
||||
audit_memory,
|
||||
)
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryAssociative, MemoryRelation, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
_FERNET = Fernet(_FERNET_KEY.encode())
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user(db_session):
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def _enc(text: str) -> str:
|
||||
return _FERNET.encrypt(text.encode()).decode()
|
||||
|
||||
|
||||
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
|
||||
return MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=_enc(text),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
|
||||
return MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
subject_label=subject,
|
||||
subject_type="person",
|
||||
predicate=predicate,
|
||||
object_label=obj,
|
||||
object_type="company",
|
||||
confidence=0.8,
|
||||
)
|
||||
|
||||
|
||||
def _llm_response(content: str) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
return msg
|
||||
|
||||
|
||||
def _mock_llm(content: str) -> MagicMock:
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response(content))
|
||||
return llm
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
|
||||
"""Context manager that patches all external deps for audit helpers."""
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm_mock)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor")
|
||||
)
|
||||
stack.enter_context(
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||
return_value=(prompt_text, None),
|
||||
)
|
||||
)
|
||||
stack.enter_context(
|
||||
patch(
|
||||
"app.core.memory_maintenance.compile_prompt",
|
||||
side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
|
||||
)
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
# ── Test 1: audit_memory never raises ────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_memory_never_raises_on_missing_user(db_session):
|
||||
"""audit_memory with a non-existent user_id must not raise."""
|
||||
await audit_memory(db_session, str(uuid.uuid4()))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
|
||||
"""audit_memory must swallow inner exceptions."""
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch(
|
||||
"app.core.memory_maintenance.get_prompt_or_fallback",
|
||||
return_value=("p {facts}", None),
|
||||
),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await audit_memory(db_session, PRO_USER_ID)
|
||||
|
||||
|
||||
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
|
||||
row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||
|
||||
with _patch_audit(llm):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
|
||||
keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
|
||||
drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
|
||||
db_session.add(keep)
|
||||
db_session.add(drop)
|
||||
await db_session.commit()
|
||||
|
||||
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
|
||||
llm = _mock_llm(deletion_payload)
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
remaining = result.scalars().all()
|
||||
remaining_ids = {r.id for r in remaining}
|
||||
assert keep.id in remaining_ids
|
||||
assert drop.id not in remaining_ids
|
||||
|
||||
|
||||
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
|
||||
|
||||
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm('"unexpected string"')
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
assert len(result.scalars().all()) == 2
|
||||
|
||||
|
||||
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
llm.ainvoke.assert_not_called()
|
||||
|
||||
|
||||
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
|
||||
row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
|
||||
row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
|
||||
db_session.add(row_a)
|
||||
db_session.add(row_b)
|
||||
db_session.add(row_c)
|
||||
await db_session.commit()
|
||||
|
||||
groups = json.dumps([
|
||||
{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
|
||||
])
|
||||
llm = _mock_llm(groups)
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row_a)
|
||||
await db_session.refresh(row_b)
|
||||
await db_session.refresh(row_c)
|
||||
|
||||
assert row_a.subject_label == "Giulia"
|
||||
assert row_b.subject_label == "Giulia"
|
||||
assert row_c.object_label == "Giulia"
|
||||
assert row_c.subject_label == "Marco"
|
||||
|
||||
|
||||
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "giulia"
|
||||
|
||||
|
||||
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
|
||||
with _patch_audit(llm, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "Giulia"
|
||||
|
||||
|
||||
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_works_without_langfuse(db_session, pro_user):
|
||||
keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
|
||||
drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
|
||||
db_session.add(keep)
|
||||
db_session.add(drop)
|
||||
await db_session.commit()
|
||||
|
||||
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
|
||||
llm = _mock_llm(deletion_payload)
|
||||
|
||||
with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
|
||||
)
|
||||
remaining_ids = {r.id for r in result.scalars().all()}
|
||||
assert keep.id in remaining_ids
|
||||
assert drop.id not in remaining_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
|
||||
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
|
||||
db_session.add(row)
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
|
||||
await db_session.commit()
|
||||
|
||||
groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
|
||||
llm = _mock_llm(groups)
|
||||
|
||||
with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
await db_session.refresh(row)
|
||||
assert row.subject_label == "Giulia"
|
||||
|
||||
|
||||
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||
for text in ("Fact A", "Fact B"):
|
||||
db_session.add(_assoc_row(PRO_USER_ID, text))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
mock_get_prompt = MagicMock(return_value=("p {facts}", None))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
|
||||
|
||||
mock_get_prompt.assert_called_once()
|
||||
assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
|
||||
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
|
||||
await db_session.commit()
|
||||
|
||||
llm = _mock_llm("[]")
|
||||
mock_get_prompt = MagicMock(return_value=("p {labels}", None))
|
||||
|
||||
with (
|
||||
patch("app.core.llm.get_agent_llm", return_value=llm),
|
||||
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
|
||||
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
|
||||
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
|
||||
):
|
||||
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
|
||||
|
||||
mock_get_prompt.assert_called_once()
|
||||
assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"
|
||||
345
tests/test_memory_extraction.py
Normal file
345
tests/test_memory_extraction.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
|
||||
|
||||
Coverage:
|
||||
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
|
||||
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
|
||||
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
|
||||
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_extraction import (
|
||||
ExtractionResult,
|
||||
MemoryCandidate,
|
||||
decide_action,
|
||||
extract_candidates,
|
||||
run_extraction,
|
||||
)
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import ExtractionQueue, MemoryCore, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user(db_session):
|
||||
"""Update the seeded pro user to have an encryption_key."""
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def free_user(db_session):
|
||||
"""Update the seeded free user to have an encryption_key."""
|
||||
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def _make_llm_response(content: str) -> MagicMock:
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
return msg
|
||||
|
||||
|
||||
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_candidates_returns_valid_result():
|
||||
payload = {
|
||||
"candidates": [
|
||||
{
|
||||
"type": "fact",
|
||||
"content": "User's CFO is Giulia",
|
||||
"target_tier": "core",
|
||||
"subject": None,
|
||||
"predicate": None,
|
||||
"object": None,
|
||||
"confidence": 0.85,
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_response = _make_llm_response(json.dumps(payload))
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = (
|
||||
"system prompt {last_turn} {core_memory} {recent_episodes}",
|
||||
None,
|
||||
)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.bind.return_value = llm_instance
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
result = await extract_candidates(
|
||||
last_turn="User: My CFO is Giulia\nAssistant: Noted.",
|
||||
core_memory={},
|
||||
recent_episodes=[],
|
||||
)
|
||||
|
||||
assert isinstance(result, ExtractionResult)
|
||||
assert len(result.candidates) == 1
|
||||
assert result.candidates[0].type == "fact"
|
||||
assert "Giulia" in result.candidates[0].content
|
||||
assert result.candidates[0].confidence == 0.85
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_candidates_returns_empty_on_llm_failure():
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.bind.return_value = llm_instance
|
||||
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
result = await extract_candidates("turn", {}, [])
|
||||
|
||||
assert isinstance(result, ExtractionResult)
|
||||
assert result.candidates == []
|
||||
|
||||
|
||||
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_add_when_no_existing():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
|
||||
action = await decide_action(candidate, existing=[])
|
||||
assert action == "ADD"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_noop():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
|
||||
mock_response = _make_llm_response("NOOP")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||
|
||||
assert action == "NOOP"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_update():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
|
||||
mock_response = _make_llm_response("UPDATE")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||
|
||||
assert action == "UPDATE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_delete():
|
||||
candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
|
||||
mock_response = _make_llm_response("DELETE")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["CFO is Giulia"])
|
||||
|
||||
assert action == "DELETE"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_action_defaults_add_on_llm_failure():
|
||||
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
|
||||
):
|
||||
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
action = await decide_action(candidate, existing=["old memory"])
|
||||
|
||||
assert action == "ADD"
|
||||
|
||||
|
||||
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
|
||||
"""'My CFO is Giulia' → fact candidate → core row written."""
|
||||
fact_payload = {
|
||||
"candidates": [
|
||||
{
|
||||
"type": "fact",
|
||||
"content": "User prefers morning meetings",
|
||||
"target_tier": "core",
|
||||
"confidence": 0.8,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
def _mock_llm_response(content: str):
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.usage_metadata = {}
|
||||
return msg
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def _ainvoke_side_effect(messages):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# extract_candidates call
|
||||
return _mock_llm_response(json.dumps(fact_payload))
|
||||
# decide_action — no existing → short-circuits to ADD without LLM
|
||||
return _mock_llm_response("ADD")
|
||||
|
||||
with (
|
||||
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
|
||||
patch("app.core.memory_extraction.get_langfuse", return_value=None),
|
||||
patch(
|
||||
"app.core.memory_extraction.get_prompt_or_fallback",
|
||||
side_effect=lambda name, fb: (
|
||||
("p {last_turn} {core_memory} {recent_episodes}", None)
|
||||
if name == "memory_extraction"
|
||||
else ("p {candidate} {existing_memories}", None)
|
||||
),
|
||||
),
|
||||
):
|
||||
llm_instance = MagicMock()
|
||||
llm_instance.bind.return_value = llm_instance
|
||||
llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)
|
||||
mock_get_llm.return_value = llm_instance
|
||||
|
||||
await run_extraction(
|
||||
db=db_session,
|
||||
user_id=PRO_USER_ID,
|
||||
last_user_msg="My CFO is Giulia",
|
||||
last_assistant_msg="Noted, I will remember that.",
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
# core row should exist
|
||||
result = await db_session.execute(
|
||||
select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) >= 1
|
||||
fernet = Fernet(_FERNET_KEY.encode())
|
||||
values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
|
||||
assert any("morning meetings" in v for v in values)
|
||||
|
||||
|
||||
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_realtime_for_pro(db_session, pro_user):
|
||||
"""Pro user: asyncio.create_task called (not queue row)."""
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
|
||||
with (
|
||||
patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
|
||||
patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
|
||||
):
|
||||
await middleware._dispatch_extraction(
|
||||
user_id=PRO_USER_ID,
|
||||
episode_id=str(uuid.uuid4()),
|
||||
last_user_msg="hello",
|
||||
last_assistant_msg="hi",
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
mock_task.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_queue_for_free(db_session, free_user):
|
||||
"""Free user: ExtractionQueue row inserted."""
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
ep_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
|
||||
await middleware._dispatch_extraction(
|
||||
user_id=FREE_USER_ID,
|
||||
episode_id=ep_id,
|
||||
last_user_msg="hello",
|
||||
last_assistant_msg="hi",
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].episode_id == ep_id
|
||||
@@ -12,13 +12,14 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.embeddings import embed_text
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
@@ -341,3 +342,33 @@ def test_home_request_calls_memory_middleware(client):
|
||||
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||
assert stored_session_id == session_id
|
||||
assert stored_message == "Show tasks"
|
||||
|
||||
|
||||
# ── embed_text ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_text_returns_1536_floats():
|
||||
"""embed_text returns a 1536-dim float list when OpenAI responds successfully."""
|
||||
fake_embedding = [0.1] * 1536
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock(embedding=fake_embedding)]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
|
||||
result = await embed_text("test text")
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1536
|
||||
assert all(isinstance(x, float) for x in result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_text_returns_none_on_failure():
|
||||
"""embed_text returns None when OpenAI raises; must not propagate the exception."""
|
||||
with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
|
||||
result = await embed_text("test text")
|
||||
|
||||
assert result is None
|
||||
|
||||
153
tests/test_memory_proactive.py
Normal file
153
tests/test_memory_proactive.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for Phase 5 — proactive hints surfacing.
|
||||
|
||||
Coverage:
|
||||
1. _proactive_hints_injection returns correct section for seeded hints
|
||||
2. _proactive_hints_injection returns empty string when no hints
|
||||
3. enrich_context includes proactive_hints key from MemoryProactive row
|
||||
4. System prompt includes proactive line when row exists + confidence >= threshold
|
||||
5. TierManager.check_feature returns True for power/team, False for free/pro
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.core.deep_agent import _proactive_hints_injection
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryProactive, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
|
||||
USER_ID = TEST_USER_IDS["power"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_with_key(db_session):
|
||||
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def _enc(plaintext: str) -> str:
|
||||
return Fernet(_FERNET_KEY.encode()).encrypt(plaintext.encode()).decode()
|
||||
|
||||
|
||||
# ── _proactive_hints_injection unit tests ─────────────────────────────────────
|
||||
|
||||
def test_proactive_hints_injection_with_hints():
|
||||
context = {"proactive_hints": ["Works late on Thursdays", "Prefers bullet points"]}
|
||||
result = _proactive_hints_injection(context)
|
||||
assert "I noticed" in result
|
||||
assert "Works late on Thursdays" in result
|
||||
assert "Prefers bullet points" in result
|
||||
|
||||
|
||||
def test_proactive_hints_injection_empty():
|
||||
assert _proactive_hints_injection({}) == ""
|
||||
assert _proactive_hints_injection({"proactive_hints": []}) == ""
|
||||
assert _proactive_hints_injection({"proactive_hints": None}) == ""
|
||||
|
||||
|
||||
def test_proactive_hints_injection_truncates_long_hints():
|
||||
hints = ["x" * 200] * 10
|
||||
result = _proactive_hints_injection({"proactive_hints": hints})
|
||||
assert len(result) <= 600
|
||||
assert result.endswith("...")
|
||||
|
||||
|
||||
# ── enrich_context includes proactive hints ───────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||
pattern = "Always checks tasks before meetings"
|
||||
db_session.add(MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=USER_ID,
|
||||
pattern_encrypted=_enc(pattern),
|
||||
confidence=0.8,
|
||||
source="inferred",
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
ctx = await middleware.enrich_context(USER_ID, "test message")
|
||||
|
||||
assert "proactive_hints" in ctx
|
||||
assert pattern in ctx["proactive_hints"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_context_excludes_low_confidence_proactive(db_session, user_with_key):
|
||||
pattern = "Low confidence pattern"
|
||||
db_session.add(MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=USER_ID,
|
||||
pattern_encrypted=_enc(pattern),
|
||||
confidence=0.1,
|
||||
source="inferred",
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
ctx = await middleware.enrich_context(USER_ID, "test message")
|
||||
|
||||
hints = ctx.get("proactive_hints", [])
|
||||
assert pattern not in hints
|
||||
|
||||
|
||||
# ── proactive hints appear in system prompt string ───────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proactive_hints_in_system_prompt_string(db_session, user_with_key):
|
||||
pattern = "Frequently requests end-of-day summaries"
|
||||
db_session.add(MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=USER_ID,
|
||||
pattern_encrypted=_enc(pattern),
|
||||
confidence=0.75,
|
||||
source="inferred",
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
middleware = MemoryMiddleware(db_session)
|
||||
ctx = await middleware.enrich_context(USER_ID, "summarize my day")
|
||||
|
||||
system_prompt_suffix = _proactive_hints_injection(ctx)
|
||||
assert pattern in system_prompt_suffix
|
||||
|
||||
|
||||
# ── Tier gate ─────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.parametrize("tier,expected", [
|
||||
("free", False),
|
||||
("pro", False),
|
||||
("power", True),
|
||||
("team", True),
|
||||
])
|
||||
def test_proactive_mining_tier_gate(tier, expected):
|
||||
assert tier_manager.check_feature(tier, "proactive_mining") == expected
|
||||
220
tests/test_memory_relations.py
Normal file
220
tests/test_memory_relations.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Tests for Phase 3 — relational tier (Mem0g-light).
|
||||
|
||||
Coverage:
|
||||
1. upsert_relation inserts a row and query_relations returns it
|
||||
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
|
||||
3. tier gating: Free user gets empty list from query_relations + enrich_context
|
||||
4. enrich_context includes relational_memory key for Pro user
|
||||
5. decay_relations decays confidence and prunes rows below threshold
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_maintenance import decay_relations
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import MemoryRelation, User
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
PRO_USER_ID = TEST_USER_IDS["pro"]
|
||||
FREE_USER_ID = TEST_USER_IDS["free"]
|
||||
_FERNET_KEY = Fernet.generate_key().decode()
|
||||
|
||||
|
||||
# ── DB override ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pro_user_with_key(db_session):
|
||||
"""Set encryption_key on the pro test user so Fernet works."""
|
||||
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def free_user_with_key(db_session):
|
||||
"""Set encryption_key on the free test user."""
|
||||
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
|
||||
user = result.scalar_one()
|
||||
user.encryption_key = _FERNET_KEY
|
||||
await db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
|
||||
"""upsert_relation inserts a row; query_relations returns it."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Giulia",
|
||||
subject_type="person",
|
||||
predicate="works_at",
|
||||
object_="Acme Corp",
|
||||
object_type="company",
|
||||
confidence=0.9,
|
||||
)
|
||||
rows = await mm.query_relations(PRO_USER_ID, subject="Giulia")
|
||||
assert len(rows) == 1
|
||||
assert rows[0].subject_label == "Giulia"
|
||||
assert rows[0].predicate == "works_at"
|
||||
assert rows[0].object_label == "Acme Corp"
|
||||
assert abs(rows[0].confidence - 0.9) < 0.001
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
|
||||
"""Second upsert on same triple updates confidence and last_confirmed_at."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Marco",
|
||||
subject_type="person",
|
||||
predicate="stakeholder_of",
|
||||
object_="Project Nexus",
|
||||
object_type="project",
|
||||
confidence=0.7,
|
||||
)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Marco",
|
||||
subject_type="person",
|
||||
predicate="stakeholder_of",
|
||||
object_="Project Nexus",
|
||||
object_type="project",
|
||||
confidence=0.95,
|
||||
)
|
||||
rows = await mm.query_relations(PRO_USER_ID, subject="Marco")
|
||||
# Only one row despite two upserts
|
||||
assert len(rows) == 1
|
||||
assert abs(rows[0].confidence - 0.95) < 0.001
|
||||
assert rows[0].last_confirmed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
|
||||
"""Free user: upsert_relation is silently skipped (no row created)."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
FREE_USER_ID,
|
||||
subject="Alice",
|
||||
subject_type="person",
|
||||
predicate="reports_to",
|
||||
object_="Bob",
|
||||
object_type="person",
|
||||
confidence=0.8,
|
||||
)
|
||||
rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
|
||||
assert rows == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
|
||||
"""enrich_context includes relational_memory key for Pro user."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
await mm.upsert_relation(
|
||||
PRO_USER_ID,
|
||||
subject="Elena",
|
||||
subject_type="person",
|
||||
predicate="cfo_of",
|
||||
object_="StartupXYZ",
|
||||
object_type="company",
|
||||
confidence=0.85,
|
||||
)
|
||||
|
||||
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
|
||||
ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?")
|
||||
|
||||
assert "relational_memory" in ctx
|
||||
assert any("Elena" in r for r in ctx["relational_memory"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
|
||||
"""Free user: relational_memory is empty list in enrich_context."""
|
||||
mm = MemoryMiddleware(db_session)
|
||||
|
||||
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
|
||||
ctx = await mm.enrich_context(FREE_USER_ID, "test message")
|
||||
|
||||
assert ctx.get("relational_memory") == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
|
||||
"""decay_relations reduces confidence on stale rows."""
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=35)
|
||||
row = MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=PRO_USER_ID,
|
||||
subject_label="OldContact",
|
||||
subject_type="person",
|
||||
predicate="knows",
|
||||
object_label="SomeProject",
|
||||
object_type="project",
|
||||
confidence=0.8,
|
||||
last_confirmed_at=old_date,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
await decay_relations(db_session, PRO_USER_ID)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
|
||||
)
|
||||
updated = result.scalar_one_or_none()
|
||||
assert updated is not None
|
||||
assert updated.confidence < 0.8
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
|
||||
"""decay_relations deletes rows whose confidence drops below 0.2 threshold."""
|
||||
# Start at 0.21 with 60-day-old last_confirmed_at → two decay periods → 0.21 * 0.95^2 ≈ 0.19 → pruned
|
||||
old_date = datetime.now(timezone.utc) - timedelta(days=65)
|
||||
row = MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=PRO_USER_ID,
|
||||
subject_label="ExpiredContact",
|
||||
subject_type="person",
|
||||
predicate="used_to_work_with",
|
||||
object_label="OldCorp",
|
||||
object_type="company",
|
||||
confidence=0.21,
|
||||
last_confirmed_at=old_date,
|
||||
)
|
||||
db_session.add(row)
|
||||
await db_session.commit()
|
||||
|
||||
await decay_relations(db_session, PRO_USER_ID)
|
||||
|
||||
result = await db_session.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
|
||||
)
|
||||
pruned = result.scalar_one_or_none()
|
||||
assert pruned is None
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
from app.core.output_formatter import StreamFormatter
|
||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
|
||||
async def _stream(*events: tuple[str, object]):
|
||||
@@ -36,29 +36,6 @@ async def test_stream_formatter_text_stream() -> None:
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_formatter_floating_domain_first() -> None:
|
||||
formatter = StreamFormatter(request_id="req-2")
|
||||
frames = await _collect(
|
||||
formatter,
|
||||
_stream(
|
||||
(
|
||||
"floating_domain",
|
||||
{"type": "node", "id": "n-1", "section": None},
|
||||
),
|
||||
("token", "Summary"),
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(frames[0], WsFloatingDomain)
|
||||
assert frames[0].domain.type == "node"
|
||||
assert frames[0].domain.id == "n-1"
|
||||
assert isinstance(frames[1], WsStreamStart)
|
||||
assert isinstance(frames[2], WsStreamText)
|
||||
assert frames[2].chunk == "Summary"
|
||||
assert isinstance(frames[-1], WsStreamEnd)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_formatter_ignores_unknown_events() -> None:
|
||||
formatter = StreamFormatter(request_id="req-3")
|
||||
|
||||
85
tests/test_run_contextual.py
Normal file
85
tests/test_run_contextual.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Tests for run_contextual_stream.
|
||||
|
||||
These tests monkeypatch _run_single_agent_stream (the actual internal runner)
|
||||
rather than the plan's fictional _run_agent_loop, matching the real
|
||||
deep_agent.py architecture.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from app.schemas.contextual import ContextualScope
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_contextual_stream_includes_scope_block(monkeypatch):
|
||||
"""run_contextual_stream must inject the scope block into the system prompt
|
||||
and include get_page_details in the tool list while excluding note-edit tools."""
|
||||
import app.core.deep_agent as deep_agent
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_stream(
|
||||
*,
|
||||
user_id,
|
||||
system_prompt,
|
||||
message,
|
||||
context,
|
||||
agent_name="agent",
|
||||
tools=None,
|
||||
conversation_history=None,
|
||||
**kwargs,
|
||||
):
|
||||
captured["sys"] = system_prompt
|
||||
captured["tool_names"] = [getattr(t, "name", str(t)) for t in (tools or [])]
|
||||
captured["agent_name"] = agent_name
|
||||
# Async generator that yields nothing — still satisfies the protocol.
|
||||
if False:
|
||||
yield # pragma: no cover
|
||||
|
||||
monkeypatch.setattr(deep_agent, "_run_single_agent_stream", fake_stream)
|
||||
|
||||
scope = ContextualScope(
|
||||
page="project",
|
||||
entity_type="project",
|
||||
entity_id="p1",
|
||||
entity_name="Acme",
|
||||
counts={"tasks": 1, "notes": 0, "milestones": 0},
|
||||
)
|
||||
|
||||
context = {
|
||||
"conversation_history": [],
|
||||
"_debug": {"session_id": "s1"},
|
||||
}
|
||||
|
||||
results = []
|
||||
async for item in deep_agent.run_contextual_stream(
|
||||
user_id="user1",
|
||||
message="hi",
|
||||
context=context,
|
||||
scope=scope,
|
||||
):
|
||||
results.append(item)
|
||||
|
||||
assert "Acme" in captured["sys"], "scope block must appear in system prompt"
|
||||
assert "Current view" in captured["sys"], "section header must be present"
|
||||
|
||||
names = captured["tool_names"]
|
||||
assert "get_page_details" in names, "get_page_details tool must be included"
|
||||
|
||||
# Entity-create tools: at least one of these must be present.
|
||||
assert any(n in names for n in ("create_task", "create_note", "update_task")), (
|
||||
"at least one entity-create tool must be present"
|
||||
)
|
||||
|
||||
assert "create_timeline" in names, "create_timeline tool must be included"
|
||||
|
||||
# Note edit tools must NOT be exposed.
|
||||
assert "propose_note_edit" not in names, "propose_note_edit must be excluded"
|
||||
|
||||
# Legacy read tools must be excluded — they return shallow snapshots and
|
||||
# cause the agent to under-answer (see trace 0b46841484ba7d024ed9f8d5ac8b1df0).
|
||||
assert "list_projects" not in names, "list_projects must be excluded (legacy read)"
|
||||
assert "get_project" not in names, "get_project must be excluded (legacy read)"
|
||||
assert "list_tasks" not in names, "list_tasks must be excluded (legacy read)"
|
||||
assert "get_task" not in names, "get_task must be excluded (legacy read)"
|
||||
assert "list_notes" not in names, "list_notes must be excluded (legacy read)"
|
||||
assert "get_note" not in names, "get_note must be excluded (legacy read)"
|
||||
@@ -4,12 +4,8 @@ import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas import (
|
||||
WsDomain,
|
||||
WsFrameType,
|
||||
WsHomeRequest,
|
||||
WsFloatingDomain,
|
||||
WsFloatingRequest,
|
||||
WsFloatingScope,
|
||||
WsStreamEnd,
|
||||
WsStreamStart,
|
||||
WsStreamText,
|
||||
@@ -22,11 +18,9 @@ from app.schemas import (
|
||||
def test_v3_frame_types_exist():
|
||||
v3_types = [
|
||||
"home_request",
|
||||
"floating_request",
|
||||
"stream_start",
|
||||
"stream_text",
|
||||
"stream_end",
|
||||
"floating_domain",
|
||||
"data_request",
|
||||
"data_response",
|
||||
"mutation",
|
||||
@@ -45,9 +39,6 @@ def test_v2_frame_types_still_exist():
|
||||
"tool_result",
|
||||
"final",
|
||||
"ping",
|
||||
"agent_run",
|
||||
"agent_data",
|
||||
"agent_complete",
|
||||
"device_hello",
|
||||
]
|
||||
for name in v2_types:
|
||||
@@ -89,51 +80,6 @@ def test_home_request_requires_message():
|
||||
WsHomeRequest.model_validate({"type": "home_request"})
|
||||
|
||||
|
||||
# ── WsFloatingRequest ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_floating_request_basic():
|
||||
frame = WsFloatingRequest(
|
||||
message="Summarise",
|
||||
scope=WsFloatingScope(type="task", id="task-123"),
|
||||
)
|
||||
assert frame.type == WsFrameType.floating_request
|
||||
assert frame.scope.type == "task"
|
||||
assert frame.scope.id == "task-123"
|
||||
|
||||
|
||||
def test_floating_request_scope_without_id():
|
||||
frame = WsFloatingRequest(
|
||||
message="Show all",
|
||||
scope=WsFloatingScope(type="project"),
|
||||
)
|
||||
assert frame.scope.id is None
|
||||
|
||||
|
||||
def test_floating_request_serializes():
|
||||
frame = WsFloatingRequest(
|
||||
message="Test",
|
||||
scope=WsFloatingScope(type="note", id="n-1"),
|
||||
)
|
||||
data = frame.model_dump()
|
||||
assert data["type"] == "floating_request"
|
||||
assert data["scope"]["type"] == "note"
|
||||
assert data["scope"]["id"] == "n-1"
|
||||
|
||||
|
||||
def test_floating_request_invalid_scope_type():
|
||||
with pytest.raises(ValidationError):
|
||||
WsFloatingRequest(
|
||||
message="X",
|
||||
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def test_floating_request_requires_scope():
|
||||
with pytest.raises(ValidationError):
|
||||
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
|
||||
|
||||
|
||||
# ── WsStreamStart ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -192,51 +138,3 @@ def test_stream_end_deserializes():
|
||||
assert frame.request_id == "r3"
|
||||
|
||||
|
||||
# ── WsFloatingDomain ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_floating_domain_tasks():
|
||||
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
||||
assert frame.type == WsFrameType.floating_domain
|
||||
assert frame.domain.type == "task"
|
||||
|
||||
|
||||
def test_floating_domain_valid_domains():
|
||||
frame = WsFloatingDomain(
|
||||
request_id="r1",
|
||||
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
||||
)
|
||||
assert frame.domain.type == "project"
|
||||
assert frame.domain.id == "213213-312321-312312-421321"
|
||||
assert frame.domain.section == "task"
|
||||
|
||||
|
||||
def test_floating_domain_object_valid():
|
||||
frame = WsFloatingDomain(
|
||||
request_id="r1",
|
||||
domain=WsDomain(type="project", id="p1", section="task"),
|
||||
)
|
||||
assert frame.domain.type == "project"
|
||||
|
||||
|
||||
def test_floating_domain_serializes():
|
||||
d = WsFloatingDomain(
|
||||
request_id="r1",
|
||||
domain=WsDomain(type="timeline"),
|
||||
).model_dump()
|
||||
assert d == {
|
||||
"type": "floating_domain",
|
||||
"request_id": "r1",
|
||||
"domain": {"type": "timeline", "id": None, "section": None},
|
||||
}
|
||||
|
||||
|
||||
def test_floating_domain_deserializes():
|
||||
raw = {
|
||||
"type": "floating_domain",
|
||||
"request_id": "r1",
|
||||
"domain": {"type": "node", "id": "n-1", "section": None},
|
||||
}
|
||||
frame = WsFloatingDomain.model_validate(raw)
|
||||
assert frame.domain.type == "node"
|
||||
assert frame.domain.id == "n-1"
|
||||
|
||||
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Tests for WS folder index_session handlers (Task 9).
|
||||
|
||||
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||
no real WS connection or LLM call is made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.api.routes.device_ws import (
|
||||
_handle_index_session_start,
|
||||
_handle_index_file_batch,
|
||||
_handle_index_session_cancel,
|
||||
_index_sessions,
|
||||
)
|
||||
from app.billing.quota import add_token_usage
|
||||
from app.core.folder_indexer import IndexResult
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import WsFrameType
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
USER_ID = TEST_USER_IDS["free"]
|
||||
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||
|
||||
|
||||
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||
|
||||
class _FakeWebSocket:
|
||||
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent: list[dict] = []
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
self.sent.append(json.loads(text))
|
||||
|
||||
def sent_types(self) -> list[str]:
|
||||
return [f["type"] for f in self.sent]
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _make_session_id() -> str:
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||
async def _fake(**kwargs) -> IndexResult:
|
||||
return IndexResult(summary=summary, tokens_used=tokens)
|
||||
return _fake
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def _clean_sessions():
|
||||
"""Ensure _index_sessions is empty before and after each test."""
|
||||
_index_sessions.clear()
|
||||
yield
|
||||
_index_sessions.clear()
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def test_index_session_happy_path(db_session):
|
||||
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Register session.
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"projectId": "proj-1",
|
||||
"totalFiles": 2,
|
||||
})
|
||||
|
||||
# Verify session was registered.
|
||||
assert session_id in _index_sessions
|
||||
assert _index_sessions[session_id]["total"] == 2
|
||||
assert _index_sessions[session_id]["processed"] == 0
|
||||
# No response frames expected for session_start.
|
||||
assert ws.sent == []
|
||||
|
||||
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||
with patch(
|
||||
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||
# We patch the module-level function in folder_indexer instead:
|
||||
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
# Wire db_session into the context manager.
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||
assert types.count(WsFrameType.index_file_result) == 2
|
||||
assert types.count(WsFrameType.index_session_progress) == 1
|
||||
assert types.count(WsFrameType.index_session_done) == 1
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "completed"
|
||||
|
||||
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||
assert progress_frame["processed"] == 2
|
||||
assert progress_frame["total"] == 2
|
||||
|
||||
# Verify session cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_cancel(db_session):
|
||||
"""start then cancel → index_session_done(cancelled)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 5,
|
||||
})
|
||||
assert session_id in _index_sessions
|
||||
|
||||
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||
|
||||
types = ws.sent_types()
|
||||
assert WsFrameType.index_session_done in types
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "cancelled"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_quota_exceeded(db_session):
|
||||
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
db_session.add(MonthlyTokenUsage(
|
||||
user_id=USER_ID,
|
||||
year_month=ym,
|
||||
feature="folder_index",
|
||||
tokens_used=100_000, # free tier cap exactly
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 1,
|
||||
})
|
||||
|
||||
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Should have 1 file result (success) then done(quota_exceeded).
|
||||
assert WsFrameType.index_file_result in types
|
||||
assert WsFrameType.index_session_done in types
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "quota_exceeded"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Integration tests for the unified WebSocket handler (Step 5).
|
||||
|
||||
Tests the device WS endpoint with home_request and floating_request frames,
|
||||
Tests the device WS endpoint with home_request frames,
|
||||
verifying that the correct v3 frame sequence is returned.
|
||||
|
||||
LLM calls are mocked to avoid network dependency.
|
||||
@@ -34,7 +34,7 @@ def _override_db(db_session):
|
||||
|
||||
|
||||
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
||||
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
|
||||
"""Receive frames until stream_end or max_frames."""
|
||||
frames = []
|
||||
for _ in range(max_frames):
|
||||
raw = ws.receive_text()
|
||||
@@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context):
|
||||
yield "token", "Hello"
|
||||
|
||||
|
||||
async def _mock_floating_stream(user_id, message, context):
|
||||
yield "floating_domain", {"type": "task", "id": None, "section": None}
|
||||
yield "token", "Here is a summary"
|
||||
|
||||
|
||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_home_request_produces_stream_frames(client):
|
||||
@@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client):
|
||||
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
|
||||
|
||||
|
||||
def test_floating_request_produces_domain_frame(client):
|
||||
"""floating_request → floating_domain first, then stream_text*, stream_end."""
|
||||
token = make_jwt("power", user_id=USER_ID)
|
||||
|
||||
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
|
||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||
ws.send_text(json.dumps({
|
||||
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
||||
}))
|
||||
ws.send_text(json.dumps({
|
||||
"type": "floating_request",
|
||||
"request_id": "p1",
|
||||
"message": "Summarize this task",
|
||||
"scope": {"type": "task", "id": "task-123"},
|
||||
}))
|
||||
frames = _recv_until_end(ws)
|
||||
|
||||
types = [f["type"] for f in frames]
|
||||
assert WsFrameType.floating_domain in types
|
||||
assert WsFrameType.stream_end in types
|
||||
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
||||
|
||||
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
||||
assert domain_frame["domain"]["type"] == "task"
|
||||
assert domain_frame["request_id"] == "p1"
|
||||
|
||||
|
||||
def test_home_request_request_id_propagated(client):
|
||||
"""request_id in home_request is echoed in all response frames."""
|
||||
token = make_jwt("power", user_id=USER_ID)
|
||||
|
||||
Reference in New Issue
Block a user