Compare commits
10 Commits
e672b58b6f
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
| 47bf1881e5 | |||
| 24a9c1b752 | |||
| 706bf88883 | |||
| 4ff0b27084 | |||
| 61d2a18234 | |||
| b3687719b6 | |||
| f80bdfa8f7 | |||
| 617a17db40 | |||
| 92716cb89a | |||
| cfc9d7a942 |
@@ -39,13 +39,6 @@ QDRANT_URL=
|
|||||||
QDRANT_API_KEY=
|
QDRANT_API_KEY=
|
||||||
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||||
|
|
||||||
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
|
||||||
LANGFUSE_SECRET_KEY=
|
|
||||||
LANGFUSE_PUBLIC_KEY=
|
|
||||||
# LANGFUSE_HOST=https://cloud.langfuse.com # EU (default)
|
|
||||||
# LANGFUSE_HOST=https://us.cloud.langfuse.com # US
|
|
||||||
# LANGFUSE_HOST=http://localhost:3000 # Self-hosted
|
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Deprecate backend agent config tables.
|
|
||||||
|
|
||||||
The Electron client is now the source of truth for agent configuration
|
|
||||||
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
|
||||||
billing checks and trigger/run logs only.
|
|
||||||
|
|
||||||
Revision ID: 9a1f2d0b6c7e
|
|
||||||
Revises: 818478c251dc
|
|
||||||
Create Date: 2026-03-16
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
revision: str = "9a1f2d0b6c7e"
|
|
||||||
down_revision: Union[str, None] = "818478c251dc"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
bind = op.get_bind()
|
|
||||||
inspector = sa.inspect(bind)
|
|
||||||
existing = set(inspector.get_table_names())
|
|
||||||
|
|
||||||
if "cloud_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
|
||||||
op.drop_table("cloud_agent_configs")
|
|
||||||
|
|
||||||
if "local_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
|
||||||
op.drop_table("local_agent_configs")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("device_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
|
||||||
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"cloud_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"provider",
|
|
||||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
|
||||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
"""add agent_config to local_agent_configs
|
|
||||||
|
|
||||||
Revision ID: a3b9c0d1e2f3
|
|
||||||
Revises: 9a1f2d0b6c7e
|
|
||||||
Create Date: 2026-04-07 00:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "a3b9c0d1e2f3"
|
|
||||||
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.add_column(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("agent_config", sa.JSON(), nullable=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_column("local_agent_configs", "agent_config")
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||||
|
|
||||||
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
|
||||||
|
|
||||||
These tools delegate to the Electron client via ``execute_on_client()`` using
|
|
||||||
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
|
||||||
handles actual disk I/O and responds with ``tool_result`` frames.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_directory(path: str) -> str:
|
|
||||||
"""List files and folders in a local directory on the user's device.
|
|
||||||
|
|
||||||
Returns a formatted listing of entries with name, type (file/directory),
|
|
||||||
and full path.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="list_directory",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
|
||||||
if not entries:
|
|
||||||
return f"Directory '{path}' is empty or does not exist."
|
|
||||||
lines: list[str] = []
|
|
||||||
for entry in entries:
|
|
||||||
entry_type = entry.get("type", "unknown")
|
|
||||||
entry_name = entry.get("name", "")
|
|
||||||
entry_path = entry.get("path", "")
|
|
||||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
|
||||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_file_content(path: str) -> str:
|
|
||||||
"""Read the text content of a local file on the user's device.
|
|
||||||
|
|
||||||
Returns the file content as a string. Large files may be truncated
|
|
||||||
by the Electron client.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
content: str = result.get("content", "")
|
|
||||||
if not content:
|
|
||||||
return f"File '{path}' is empty or could not be read."
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_file_metadata(path: str) -> str:
|
|
||||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
|
||||||
|
|
||||||
Returns a formatted summary of the file's metadata.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="get_file_metadata",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
size = result.get("size", "unknown")
|
|
||||||
created = result.get("createdAt", "unknown")
|
|
||||||
modified = result.get("modifiedAt", "unknown")
|
|
||||||
extension = result.get("extension", "unknown")
|
|
||||||
name = result.get("name", path)
|
|
||||||
return (
|
|
||||||
f"File: {name}\n"
|
|
||||||
f" Extension: {extension}\n"
|
|
||||||
f" Size: {size} bytes\n"
|
|
||||||
f" Created: {created}\n"
|
|
||||||
f" Modified: {modified}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
FILESYSTEM_TOOLS: list[Any] = [
|
|
||||||
list_directory,
|
|
||||||
read_file_content,
|
|
||||||
get_file_metadata,
|
|
||||||
]
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -10,38 +9,14 @@ from langchain_core.tools import tool
|
|||||||
from app.core.llm import embed
|
from app.core.llm import embed
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -130,10 +105,4 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
|
||||||
list_notes,
|
|
||||||
get_note,
|
|
||||||
create_note,
|
|
||||||
update_note,
|
|
||||||
delete_note,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -8,22 +8,6 @@ from langchain_core.tools import tool
|
|||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
PROJECT_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
@@ -133,11 +117,4 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
PROJECT_TOOLS: list[Any] = [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,40 +1,14 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments."""
|
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -48,12 +22,11 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": normalized_project_id or None,
|
"projectId": project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -79,6 +52,7 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -89,6 +63,7 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms; 1 when confirmed
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -102,6 +77,7 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -121,10 +97,12 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the value
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -141,6 +119,8 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
@@ -208,12 +188,8 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result.get("row", {})
|
row = result["row"]
|
||||||
row_author = row.get("author", author)
|
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
||||||
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
|
||||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
|
||||||
row_comment_id = row.get("id", "unknown")
|
|
||||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -223,16 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
TASK_TOOLS: list[Any] = [
|
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,45 +1,21 @@
|
|||||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
"""Timeline agent — tool definitions for project milestone CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -54,12 +30,14 @@ async def create_timeline(
|
|||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -69,6 +47,7 @@ async def create_timeline(
|
|||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -80,16 +59,20 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
@@ -106,9 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
TIMELINE_TOOLS: list[Any] = [
|
|
||||||
list_timelines,
|
|
||||||
create_timeline,
|
|
||||||
update_timeline,
|
|
||||||
delete_timeline,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -55,15 +55,12 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
|
||||||
# block local development when no Stripe subscription exists.
|
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
|
|||||||
@@ -1,72 +1,74 @@
|
|||||||
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
Endpoints:
|
||||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
POST /agents/journey/start — start a new journey session
|
||||||
frames to the functions exported here.
|
POST /agents/journey/message — continue the conversation
|
||||||
|
|
||||||
|
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
||||||
|
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
||||||
data_types, schedule).
|
2. Server creates a session, calls the LLM with a contextual system prompt,
|
||||||
2. Server creates an in-memory session, sets up a WS executor so the
|
and returns the first question.
|
||||||
setup LLM can use file-system tools, does a first directory scrape,
|
3. Client sends follow-up messages to ``/message``.
|
||||||
and sends back a ``journey_reply`` with the first question.
|
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
||||||
3. FE sends ``journey_message`` frames for each user reply.
|
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
4. Server appends the user message, calls the LLM (which may read files
|
5. Server parses the block, sets ``done=True``, and returns the template.
|
||||||
via tools), and sends back a ``journey_reply``.
|
|
||||||
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
The ``prompt_template`` from the final response is meant to be stored in
|
||||||
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
||||||
6. Server parses and validates the JSON with Pydantic, sends
|
by the Electron client (via the agent CRUD endpoints).
|
||||||
``journey_reply`` with ``done=True`` and the serialised config.
|
|
||||||
FE stores it locally.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.schemas import AgentConfig
|
from app.db import get_session
|
||||||
|
from app.models import CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
JourneyMessageRequest,
|
||||||
|
JourneyResponse,
|
||||||
|
JourneyStartRequest,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
_CONFIG_START = "AGENT_CONFIG_START"
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
_CONFIG_END = "AGENT_CONFIG_END"
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
_MAX_TURNS: int = 5
|
||||||
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
|
||||||
_MAX_TURNS: int = 15
|
|
||||||
# Max tool-calling steps per LLM invocation.
|
|
||||||
_MAX_TOOL_STEPS: int = 6
|
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class JourneySession:
|
class _JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
directory: str
|
|
||||||
data_types: list[str]
|
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
system_prompt: str = ""
|
|
||||||
langfuse_prompt: Any = None
|
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -74,182 +76,103 @@ class JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, JourneySession] = {}
|
_sessions: dict[str, _JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
||||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
return None
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
return None
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt ─────────────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_JOURNEY_SYSTEM_PROMPT = """\
|
_LOCAL_PREAMBLE = """\
|
||||||
|
What kind of files are in the directories you want to monitor? \
|
||||||
|
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
||||||
|
|
||||||
|
_CLOUD_PREAMBLE = """\
|
||||||
|
What kind of emails or messages should I look for? \
|
||||||
|
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand what files the user has in their directory and produce a
|
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
||||||
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
- list_directory: see folder structure and file names
|
1. The type and format of the source content.
|
||||||
- read_file_content: peek at a file's content
|
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
||||||
- get_file_metadata: check file size, extension, dates
|
3. How fields should be mapped (e.g. email subject → task title).
|
||||||
|
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
5. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
The user's configured directory is: {directory}
|
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
||||||
Target data types: {data_types}
|
these exact markers on their own lines:
|
||||||
|
|
||||||
## Your process
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
### Step 1 — Explore the directory
|
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
||||||
Use list_directory and read_file_content to understand what types of files are present
|
and must return a JSON array of records in this shape:
|
||||||
(HTML emails, plain-text documents, CSVs, etc.).
|
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
||||||
|
|
||||||
### Step 2 — Identify content types
|
|
||||||
For each distinct file type found, decide:
|
|
||||||
- A short id (e.g. "email_html", "plain_text", "csv")
|
|
||||||
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
|
||||||
- A human-readable label and optional detection_hint
|
|
||||||
|
|
||||||
### Step 3 — Ask focused questions (one at a time)
|
|
||||||
Cover these topics based on what you discovered:
|
|
||||||
1. How to map content to entity types (task / note / timeline entry)
|
|
||||||
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
|
||||||
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
|
||||||
4. Date extraction (e.g. "by Friday" → dueDate)
|
|
||||||
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
|
||||||
|
|
||||||
### Step 4 — Produce the AgentConfig JSON
|
|
||||||
Once you are ≥ 90% confident, output the final config between these exact markers
|
|
||||||
(each on its own line):
|
|
||||||
|
|
||||||
{config_start}
|
|
||||||
{{
|
|
||||||
"content_types": [
|
|
||||||
{{
|
|
||||||
"id": "email_html",
|
|
||||||
"label": "Email HTML",
|
|
||||||
"detection_hint": "HTML file with From/To/Subject headers",
|
|
||||||
"preprocessing": "email_html",
|
|
||||||
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
|
||||||
}}
|
|
||||||
],
|
|
||||||
"global_rules": [
|
|
||||||
"If the file cannot be matched to any project, do not create any entity."
|
|
||||||
],
|
|
||||||
"data_types": {data_types_json}
|
|
||||||
}}
|
|
||||||
{config_end}
|
|
||||||
|
|
||||||
## Rules for the extraction_prompt field
|
|
||||||
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
|
||||||
- Include field mapping rules based on what you found in the directory
|
|
||||||
- Include priority/status/date rules if applicable
|
|
||||||
- Do NOT include projectId logic — the runner handles project assignment automatically
|
|
||||||
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
- Never ask about projects, projectId, or how to link records to projects
|
|
||||||
- Never include projectId or project creation logic in the generated config
|
|
||||||
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
|
||||||
|
|
||||||
|
Rules for the generated template:
|
||||||
|
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
||||||
|
- Include concrete examples of mappings.
|
||||||
|
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
||||||
|
- Set isAiSuggested: true and isApproved: false on every record.
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Begin by exploring the directory, then ask your first question.\
|
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(
|
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
||||||
directory: str,
|
source_description = (
|
||||||
data_types: list[str],
|
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
||||||
existing_config: str | None = None,
|
)
|
||||||
) -> tuple[str, Any]:
|
|
||||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
|
||||||
existing_section = (
|
existing_section = (
|
||||||
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"```json\n{existing_config}\n```\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
if existing_config
|
if existing_template
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
template, prompt_obj = get_prompt_or_fallback(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
source_description=source_description,
|
||||||
)
|
template_start=_TEMPLATE_START,
|
||||||
compiled = compile_prompt(
|
template_end=_TEMPLATE_END,
|
||||||
template,
|
|
||||||
prompt_obj,
|
|
||||||
directory=directory,
|
|
||||||
data_types=", ".join(data_types),
|
|
||||||
data_types_json=json.dumps(data_types),
|
|
||||||
config_start=_CONFIG_START,
|
|
||||||
config_end=_CONFIG_END,
|
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
|
max_turns=_MAX_TURNS,
|
||||||
)
|
)
|
||||||
return compiled, prompt_obj
|
|
||||||
|
|
||||||
|
|
||||||
# ── AgentConfig extraction ────────────────────────────────────────────────
|
def _first_question(agent_type: str) -> str:
|
||||||
|
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
||||||
|
|
||||||
|
|
||||||
def _extract_agent_config(text: str) -> str | None:
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
"""Return validated AgentConfig JSON string from between markers, or None.
|
|
||||||
|
|
||||||
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
|
||||||
returning. Returns None if markers are absent or JSON is invalid.
|
def _extract_template(text: str) -> str | None:
|
||||||
"""
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
if _CONFIG_START not in text or _CONFIG_END not in text:
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
return None
|
|
||||||
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
|
||||||
end_idx = text.index(_CONFIG_END)
|
|
||||||
raw = text[start_idx:end_idx].strip()
|
|
||||||
if not raw:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(raw)
|
|
||||||
return parsed.model_dump_json()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
|
||||||
return None
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call with tool support ───────────────────────────────────────────
|
# ── LLM call ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
||||||
if content is None:
|
"""Build LangChain messages from history and invoke the LLM."""
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, str):
|
|
||||||
parts.append(item)
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
text = item.get("text")
|
|
||||||
if isinstance(text, str):
|
|
||||||
parts.append(text)
|
|
||||||
return "".join(parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm_with_tools(
|
|
||||||
system_prompt: str,
|
|
||||||
history: list[dict[str, Any]],
|
|
||||||
tools: list[Any],
|
|
||||||
*,
|
|
||||||
user_id: str = "",
|
|
||||||
session_id: str = "",
|
|
||||||
langfuse_prompt: Any = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
|
||||||
|
|
||||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
|
||||||
continue until a final text response is produced.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -258,238 +181,137 @@ async def _call_llm_with_tools(
|
|||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
response = await llm.ainvoke(messages)
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
return response.content # type: ignore[return-value]
|
||||||
|
|
||||||
_span_ctx = (
|
|
||||||
lf.start_as_current_observation(
|
|
||||||
as_type="span",
|
|
||||||
name="journey-setup",
|
|
||||||
metadata={"user_id": user_id or None, "session_id": session_id or None},
|
|
||||||
input=history[-1]["content"] if history else "",
|
|
||||||
)
|
|
||||||
if lf else None
|
|
||||||
)
|
|
||||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for _ in range(_MAX_TOOL_STEPS):
|
|
||||||
_gen_ctx = (
|
|
||||||
lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="journey-setup-llm",
|
|
||||||
model=settings.LLM_MODEL,
|
|
||||||
prompt=langfuse_prompt,
|
|
||||||
input=messages,
|
|
||||||
)
|
|
||||||
if lf else None
|
|
||||||
)
|
|
||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
if _gen_ctx:
|
|
||||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
|
||||||
_gen_ctx.__exit__(None, None, None)
|
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
if _span:
|
|
||||||
_span.update(output=_as_text(response.content))
|
|
||||||
return _as_text(response.content)
|
|
||||||
|
|
||||||
for call in response.tool_calls:
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"agent_setup: journey tool_call name=%s args=%s",
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_setup: journey tool_result name=%s output=%s",
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:800],
|
|
||||||
)
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
# Fallback: exceeded max steps.
|
|
||||||
final = await llm.ainvoke(messages)
|
|
||||||
final_text = _as_text(final.content)
|
|
||||||
if _span:
|
|
||||||
_span.update(output=final_text)
|
|
||||||
return final_text
|
|
||||||
finally:
|
|
||||||
if _span_ctx:
|
|
||||||
_span_ctx.__exit__(None, None, None)
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
# ── Existing-config loader ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_start(
|
async def _load_existing_template(
|
||||||
|
agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
frame: dict[str, Any],
|
db: AsyncSession,
|
||||||
) -> dict[str, Any]:
|
) -> str | None:
|
||||||
"""Handle a ``journey_start`` WS frame.
|
"""Return the prompt_template of an existing agent config, or None."""
|
||||||
|
# Try local first, then cloud.
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local = local_result.scalar_one_or_none()
|
||||||
|
if local is not None:
|
||||||
|
return local.prompt_template
|
||||||
|
|
||||||
Creates a session, runs the setup LLM with directory exploration,
|
cloud_result = await db.execute(
|
||||||
and returns the ``journey_reply`` payload.
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud = cloud_result.scalar_one_or_none()
|
||||||
|
return cloud.prompt_template if cloud is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def start_journey(
|
||||||
|
body: JourneyStartRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Start a new Chatbot Journey session.
|
||||||
|
|
||||||
|
If ``agent_id`` is provided the session is pre-seeded with the existing
|
||||||
|
agent's ``prompt_template`` so the user can refine it.
|
||||||
"""
|
"""
|
||||||
agent_type = frame.get("agent_type", "local")
|
# Load existing template (may be None).
|
||||||
directory = frame.get("directory", "")
|
existing_template: str | None = None
|
||||||
data_types = frame.get("data_types", [])
|
if body.agent_id:
|
||||||
existing_config = frame.get("existing_config")
|
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
||||||
|
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
||||||
|
# the user may be starting a fresh journey for a not-yet-persisted config).
|
||||||
|
|
||||||
# Use the session_id provided by the FE so the reply matches the
|
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
||||||
# listener key; fall back to a generated one if absent.
|
first_question = _first_question(body.agent_type)
|
||||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
|
||||||
|
|
||||||
session = JourneySession(
|
session_id = str(uuid.uuid4())
|
||||||
|
session = _JourneySession(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=current_user.id,
|
||||||
agent_type=agent_type,
|
agent_type=body.agent_type,
|
||||||
directory=directory,
|
# Seed history with the AI's first question so it stays consistent.
|
||||||
data_types=data_types,
|
history=[{"role": "assistant", "content": first_question}],
|
||||||
system_prompt=system_prompt,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
)
|
)
|
||||||
|
# Store the system prompt inside the session for reuse in /message.
|
||||||
# Seed with an initial user message — some providers require at least one
|
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
||||||
# user/input message to be present.
|
|
||||||
seed_history: list[dict[str, Any]] = [
|
|
||||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
|
||||||
]
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
history=seed_history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.history.extend(seed_history)
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info(
|
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
||||||
"agent_setup: journey session %s started for user %s (directory=%s)",
|
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
||||||
session_id,
|
|
||||||
user_id,
|
|
||||||
directory,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
|
||||||
agent_config = _extract_agent_config(ai_reply)
|
|
||||||
done = agent_config is not None
|
|
||||||
|
|
||||||
display_message = ai_reply
|
|
||||||
if done:
|
|
||||||
display_message = (
|
|
||||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": display_message,
|
|
||||||
"done": done,
|
|
||||||
"agent_config": agent_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_message(
|
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
user_id: str,
|
async def send_journey_message(
|
||||||
frame: dict[str, Any],
|
body: JourneyMessageRequest,
|
||||||
) -> dict[str, Any]:
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
"""Handle a ``journey_message`` WS frame.
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Send a message in an existing Chatbot Journey session.
|
||||||
|
|
||||||
Appends the user message, calls the LLM, and returns the
|
The server appends the user's message to the conversation history,
|
||||||
``journey_reply`` payload.
|
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
||||||
|
``prompt_template`` block the response includes ``done=True`` and the
|
||||||
|
extracted template.
|
||||||
"""
|
"""
|
||||||
session_id = frame.get("session_id", "")
|
session = _get_session(body.session_id, current_user.id)
|
||||||
message = frame.get("message", "")
|
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
||||||
|
|
||||||
session = get_journey_session(session_id, user_id)
|
# Append user turn to history.
|
||||||
if session is None:
|
session.history.append({"role": "user", "content": body.message})
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": "Journey session not found or expired. Please start a new setup.",
|
|
||||||
"done": True,
|
|
||||||
"agent_config": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Append user turn.
|
# Call the LLM with the full conversation so far.
|
||||||
session.history.append({"role": "user", "content": message})
|
ai_reply = await _call_llm(system_prompt, session.history)
|
||||||
|
|
||||||
# Call the LLM with tools.
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
user_id=session.user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Append AI turn.
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
# Check if the LLM produced the final config.
|
# Check if the LLM produced the final template.
|
||||||
agent_config = _extract_agent_config(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = agent_config is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
|
||||||
if not done:
|
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
|
||||||
if turns >= _MAX_TURNS:
|
|
||||||
nudge_content = (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
|
||||||
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
|
||||||
)
|
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
|
||||||
|
|
||||||
nudge_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
user_id=session.user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
|
||||||
)
|
|
||||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
|
||||||
|
|
||||||
agent_config = _extract_agent_config(nudge_reply)
|
|
||||||
if agent_config is not None:
|
|
||||||
done = True
|
|
||||||
ai_reply = nudge_reply
|
|
||||||
|
|
||||||
|
# Strip the sentinel markers from the message shown to the user.
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
if _CONFIG_START in ai_reply
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
else "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
|
||||||
|
|
||||||
return {
|
if done:
|
||||||
"type": "journey_reply",
|
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
||||||
"session_id": session_id,
|
# Clean up the session immediately on completion.
|
||||||
"message": display_message,
|
_sessions.pop(body.session_id, None)
|
||||||
"done": done,
|
else:
|
||||||
"agent_config": agent_config,
|
# Nudge the LLM to wrap up after max turns.
|
||||||
}
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
# Add a system-level nudge as a hidden user message.
|
||||||
|
session.history.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
return JourneyResponse(
|
||||||
|
session_id=body.session_id,
|
||||||
|
message=display_message,
|
||||||
|
done=done,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,36 +1,45 @@
|
|||||||
"""Agent routes.
|
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
||||||
|
|
||||||
Backend responsibilities are intentionally minimal:
|
Endpoints:
|
||||||
GET /agents/catalog — static catalog for UI display
|
GET /agents/catalog — hardcoded agent type catalog
|
||||||
POST /agents/can-create — billing eligibility check
|
GET /agents/local — list user's local agent configs
|
||||||
POST /agents/trigger — trigger a local agent run
|
POST /agents/local — create local agent (tier-gated)
|
||||||
|
PUT /agents/local/{agent_id} — partial update (ownership check)
|
||||||
Agent configuration is owned by the Electron app and is not persisted
|
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
||||||
in backend agent-config tables.
|
GET /agents/cloud — list user's cloud agent configs
|
||||||
|
POST /agents/cloud — create cloud agent (tier-gated)
|
||||||
|
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
||||||
|
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
from datetime import datetime
|
||||||
from datetime import datetime, timedelta, timezone
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlalchemy import func, select
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
AgentCreationCheckRequest,
|
|
||||||
AgentCreationCheckResponse,
|
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
AgentTriggerRequest,
|
CloudAgentConfigCreate,
|
||||||
|
CloudAgentConfigResponse,
|
||||||
|
CloudAgentConfigUpdate,
|
||||||
|
LocalAgentConfigCreate,
|
||||||
|
LocalAgentConfigResponse,
|
||||||
|
LocalAgentConfigUpdate,
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,21 +56,39 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
def _to_data_types(values: list[str]) -> list[str]:
|
# ── Model → schema converters ─────────────────────────────────────────
|
||||||
normalize = {
|
|
||||||
"task": "tasks", "tasks": "tasks",
|
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
||||||
"note": "notes", "notes": "notes",
|
return LocalAgentConfigResponse(
|
||||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
id=a.id,
|
||||||
"project": "projects", "projects": "projects",
|
name=a.name,
|
||||||
}
|
device_id=a.device_id,
|
||||||
seen: set[str] = set()
|
directory_paths=a.directory_paths,
|
||||||
result: list[str] = []
|
data_types=a.data_types,
|
||||||
for v in values:
|
prompt_template=a.prompt_template,
|
||||||
mapped = normalize.get(v)
|
file_extensions=a.file_extensions,
|
||||||
if mapped and mapped not in seen:
|
schedule_cron=a.schedule_cron,
|
||||||
seen.add(mapped)
|
enabled=a.enabled,
|
||||||
result.append(mapped)
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
return result
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
||||||
|
return CloudAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
provider=a.provider, # type: ignore[arg-type]
|
||||||
|
name=a.name,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
filter_config=a.filter_config,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -78,42 +105,77 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
# ── Ownership-checked lookups ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_local_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_cloud_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> CloudAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier limit helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return combined enabled local + cloud agent count for the user."""
|
||||||
|
local_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(LocalAgentConfig.id)).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
cloud_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(CloudAgentConfig.id)).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
return local_count + cloud_count
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
async def _enforce_run_frequency(
|
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
||||||
tier: str,
|
|
||||||
user_id: str,
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
|
||||||
if limit == -1:
|
|
||||||
return # unlimited
|
|
||||||
|
|
||||||
today_start = datetime.now(timezone.utc).replace(
|
class _RunsPage(BaseModel):
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
total: int
|
||||||
)
|
page: int
|
||||||
result = await db.execute(
|
limit: int
|
||||||
select(func.count(AgentRunLog.id)).where(
|
items: list[AgentRunLogResponse]
|
||||||
AgentRunLog.user_id == user_id,
|
|
||||||
AgentRunLog.started_at >= today_start,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
runs_today: int = result.scalar_one()
|
|
||||||
|
|
||||||
if runs_today >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -147,61 +209,229 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
# ── Local agent CRUD ──────────────────────────────────────────────────
|
||||||
async def can_create_agent(
|
|
||||||
body: AgentCreationCheckRequest,
|
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
||||||
|
async def list_local_agents(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> AgentCreationCheckResponse:
|
db: AsyncSession = Depends(get_session),
|
||||||
"""Check if the user can create one more agent based on billing tier.
|
) -> list[LocalAgentConfigResponse]:
|
||||||
|
"""List all local directory agent configs owned by the authenticated user."""
|
||||||
Since configuration is client-owned, the Electron app sends its current
|
result = await db.execute(
|
||||||
active agent count and the backend applies tier limits.
|
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
||||||
"""
|
|
||||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
|
||||||
allowed = limit == -1 or body.active_agents < limit
|
|
||||||
return AgentCreationCheckResponse(
|
|
||||||
allowed=allowed,
|
|
||||||
tier=current_user.tier,
|
|
||||||
active_agents=body.active_agents,
|
|
||||||
limit=limit,
|
|
||||||
)
|
)
|
||||||
|
return [_to_local_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_local_agent(
|
||||||
|
body: LocalAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Create a new local directory agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=body.name,
|
||||||
|
device_id=body.device_id,
|
||||||
|
directory_paths=body.directory_paths,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
file_extensions=body.file_extensions,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
||||||
|
async def update_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: LocalAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Partially update a local agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/local/{agent_id}", response_model=dict)
|
||||||
|
async def delete_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
||||||
|
async def list_cloud_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[CloudAgentConfigResponse]:
|
||||||
|
"""List all cloud connector agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_cloud_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_cloud_agent(
|
||||||
|
body: CloudAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Create a new cloud connector agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = CloudAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
provider=body.provider,
|
||||||
|
name=body.name,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
oauth_token_encrypted=body.oauth_token_encrypted,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
filter_config=body.filter_config,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
||||||
|
async def update_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: CloudAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Partially update a cloud agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/cloud/{agent_id}", response_model=dict)
|
||||||
|
async def delete_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Run logs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/runs", response_model=_RunsPage)
|
||||||
|
async def list_run_logs(
|
||||||
|
agent_id: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _RunsPage:
|
||||||
|
"""Return paginated run logs for the authenticated user.
|
||||||
|
|
||||||
|
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
||||||
|
"""
|
||||||
|
base_filter = [AgentRunLog.user_id == current_user.id]
|
||||||
|
if agent_id:
|
||||||
|
base_filter.append(AgentRunLog.agent_id == agent_id)
|
||||||
|
|
||||||
|
total = (
|
||||||
|
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog)
|
||||||
|
.where(*base_filter)
|
||||||
|
.order_by(AgentRunLog.started_at.desc())
|
||||||
|
.offset((page - 1) * limit)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
||||||
|
|
||||||
|
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Manual trigger stub ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
body: AgentTriggerRequest,
|
agent_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Trigger a local agent run using client-provided configuration."""
|
"""Manually trigger an agent run.
|
||||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
|
||||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
|
||||||
|
|
||||||
config = LocalAgentConfig(
|
Looks up the agent config (local or cloud) by ID with ownership check,
|
||||||
id=str(uuid.uuid4()),
|
creates a run log entry with ``status="running"``, and returns it.
|
||||||
user_id=current_user.id,
|
|
||||||
device_id=body.device_id,
|
Actual dispatch to the agent runner is wired in Step 3.4 once
|
||||||
name="Local Directory Monitor",
|
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||||
directory_paths=[body.directory],
|
"""
|
||||||
data_types=_to_data_types(body.what_to_extract),
|
# Determine agent type by trying local first, then cloud.
|
||||||
prompt_template=body.custom_agent_prompt,
|
# Keep the full config object so we can pass it to the agent runner.
|
||||||
file_extensions=[],
|
local_config: LocalAgentConfig | None = None
|
||||||
schedule_cron=body.batch_interval,
|
cloud_config: CloudAgentConfig | None = None
|
||||||
enabled=True,
|
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == current_user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
|
||||||
stable_agent_id = body.agent_id or config.id
|
|
||||||
|
|
||||||
if is_agent_running(stable_agent_id):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
|
||||||
)
|
)
|
||||||
|
local_config = local_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if local_config is not None:
|
||||||
|
agent_type = "local"
|
||||||
|
else:
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_config = cloud_result.scalar_one_or_none()
|
||||||
|
if cloud_config is not None:
|
||||||
|
agent_type = "cloud"
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=stable_agent_id,
|
agent_id=agent_id,
|
||||||
agent_type="local",
|
agent_type=agent_type,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -209,14 +439,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
run_context = {
|
# Dispatch the run as a background task — returns 202 immediately.
|
||||||
"type": "agent_batch",
|
if agent_type == "local" and local_config is not None:
|
||||||
"run_id": run_log.id,
|
|
||||||
"agent_id": stable_agent_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
elif agent_type == "cloud" and cloud_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ from fastapi.responses import JSONResponse
|
|||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import async_session
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
@@ -20,10 +22,21 @@ async def chat(
|
|||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""REST fallback for home chat when websocket streaming is unavailable."""
|
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||||
response = await run_home(
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||||
|
|
||||||
|
context = {
|
||||||
|
**body.context.model_dump(),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
response_text = await run_home(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
message=body.message,
|
message=body.message,
|
||||||
context=body.context.model_dump(),
|
context=context,
|
||||||
|
db_session_factory=async_session,
|
||||||
)
|
)
|
||||||
return JSONResponse(content={"response": response})
|
result = ChatResponse(response=response_text)
|
||||||
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Protocol:
|
|||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``journey_start`` → starts a guided setup journey session.
|
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||||
- ``journey_message`` → continues a journey conversation.
|
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
@@ -39,13 +39,12 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -148,6 +147,37 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_data:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
await queue.put(frame)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_complete:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
# Sentinel: signals the agent data stream is finished.
|
||||||
|
await queue.put(None)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -158,16 +188,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_start:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_journey_start(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_message:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_journey_message(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -180,13 +200,35 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
|
|
||||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
||||||
|
|
||||||
|
|
||||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
async def _executor(payload: dict) -> dict:
|
async def _executor(payload: dict) -> dict:
|
||||||
payload["type"] = WsFrameType.tool_call
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
call_id = payload["id"]
|
||||||
|
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||||
await websocket.send_text(json.dumps(payload))
|
await websocket.send_text(json.dumps(payload))
|
||||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
future = device_manager.create_pending_call(user_id, call_id)
|
||||||
return await future
|
try:
|
||||||
|
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
||||||
|
call_id, payload.get("action"), user_id,
|
||||||
|
)
|
||||||
|
# Clean up the pending future so it doesn't leak
|
||||||
|
conn = device_manager._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.pending_calls.pop(call_id, None)
|
||||||
|
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
||||||
|
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||||
|
call_id, type(result).__name__,
|
||||||
|
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
if result is None:
|
||||||
|
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||||
|
return result
|
||||||
return _executor
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
@@ -199,27 +241,14 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
logger.info(
|
|
||||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
user_id,
|
|
||||||
message,
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,11 +256,12 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(user_id, message, context)
|
event_stream = run_home_stream(
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
user_id, message, context, db_session_factory=async_session
|
||||||
|
)
|
||||||
|
formatter = HomeFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
# Collect text chunks to build the full response for episode storage
|
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -246,14 +276,7 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
len("".join(response_chunks)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -267,37 +290,23 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
logger.info(
|
|
||||||
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
json.dumps(scope, ensure_ascii=True)[:200],
|
|
||||||
message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
user_id,
|
|
||||||
message,
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {"scope": scope, **memory_context}
|
||||||
"scope": scope,
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_floating_stream(user_id, message, context)
|
event_stream = run_floating_stream(
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
user_id, message, context, scope=scope,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -314,72 +323,8 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
len("".join(response_chunks)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_start(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a journey_start frame — explores directory and sends first question."""
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
reply = await handle_journey_start(user_id, frame)
|
|
||||||
await websocket.send_text(json.dumps(reply))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(
|
|
||||||
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": frame.get("session_id", ""),
|
|
||||||
"message": f"Failed to start journey: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}))
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_message(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a journey_message frame — continues the journey conversation."""
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
reply = await handle_journey_message(user_id, frame)
|
|
||||||
await websocket.send_text(json.dumps(reply))
|
|
||||||
except Exception as exc:
|
|
||||||
session_id = frame.get("session_id", "")
|
|
||||||
logger.error(
|
|
||||||
"device_ws: journey_message failed user=%s session=%s: %s",
|
|
||||||
user_id, session_id, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": f"Journey error: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}))
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -415,3 +360,6 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
"batch_runs_per_day": 5,
|
|
||||||
"cloud_storage_gb": 0,
|
"cloud_storage_gb": 0,
|
||||||
"backup_gb": 0,
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
@@ -32,7 +31,6 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
"batch_runs_per_day": 50,
|
|
||||||
"cloud_storage_gb": 5,
|
"cloud_storage_gb": 5,
|
||||||
"backup_gb": 5,
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -43,7 +41,6 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
"batch_runs_per_day": -1, # unlimited
|
|
||||||
"cloud_storage_gb": 25,
|
"cloud_storage_gb": 25,
|
||||||
"backup_gb": 25,
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -54,7 +51,6 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
"batch_runs_per_day": -1, # unlimited
|
|
||||||
"cloud_storage_gb": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
"backup_gb": -1, # unlimited
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -81,18 +77,16 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
when no subscription row exists.
|
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
from app.config.settings import settings # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "power" if settings.ENV == "dev" else "free"
|
return "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -52,10 +52,6 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
LANGFUSE_SECRET_KEY: str = ""
|
|
||||||
LANGFUSE_PUBLIC_KEY: str = ""
|
|
||||||
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
|
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
"""Common base for non-chat agents still using the old base contract."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_id: str = "",
|
|
||||||
shared_memory: dict[str, Any] | None = None,
|
|
||||||
vector_store_context: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.user_id = user_id
|
|
||||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
|
||||||
self.vector_store_context: list[str] = vector_store_context or []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def skills(self) -> list[str]:
|
|
||||||
return []
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -3,15 +3,20 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager handles the **tool-call round-trip** pattern:
|
The manager participates in two interaction patterns:
|
||||||
- Backend sends ``tool_call`` frame → Electron executes the action →
|
|
||||||
returns ``tool_result`` frame.
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||||
|
``tool_result`` frame.
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
receive the result dict from Electron.
|
receive the result dict from Electron.
|
||||||
|
|
||||||
This pattern is used by all tools (CRUD, file-system, etc.) via
|
2. **Agent-data streaming** (local directory agent runs):
|
||||||
``execute_on_client()`` in ``ws_context.py``.
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||||
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||||
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||||
|
a specific ``run_id`` so the agent runner can iterate frames.
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -37,6 +42,8 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
# Per-run queues for agent_data / agent_complete frames.
|
||||||
|
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -146,6 +153,31 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
|
# ── Agent-data queue ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_agent_data_queue(
|
||||||
|
self, user_id: str, run_id: str
|
||||||
|
) -> asyncio.Queue[dict | None]:
|
||||||
|
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||||
|
|
||||||
|
The agent runner reads from this queue. The device WS handler writes
|
||||||
|
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
if run_id not in conn.agent_data_queues:
|
||||||
|
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||||
|
return conn.agent_data_queues[run_id]
|
||||||
|
|
||||||
|
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||||
|
"""Remove the queue for *run_id* once a run has completed."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.agent_data_queues.pop(run_id, None)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,147 +0,0 @@
|
|||||||
"""Langfuse observability — singleton client and prompt helpers.
|
|
||||||
|
|
||||||
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
|
||||||
all helpers are no-ops so the app works without Langfuse configured.
|
|
||||||
|
|
||||||
Usage
|
|
||||||
-----
|
|
||||||
Tracing::
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
|
||||||
span.update(input=user_message)
|
|
||||||
# ... do work ...
|
|
||||||
span.update(output=result)
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
Prompt management::
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_prompt_or_fallback
|
|
||||||
|
|
||||||
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
|
||||||
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
|
||||||
|
|
||||||
Linking a prompt to a generation::
|
|
||||||
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="llm-call",
|
|
||||||
model="gpt-4o",
|
|
||||||
prompt=prompt_obj, # links generation → prompt version in the UI
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=_usage(response))
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_client: Any = None
|
|
||||||
_initialized: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_langfuse() -> Any | None:
|
|
||||||
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
|
||||||
global _client, _initialized
|
|
||||||
if _initialized:
|
|
||||||
return _client
|
|
||||||
_initialized = True
|
|
||||||
|
|
||||||
from app.config.settings import settings # local import to avoid circular deps
|
|
||||||
|
|
||||||
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
|
||||||
logger.debug("langfuse: not configured — observability disabled")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
_client = Langfuse(
|
|
||||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
|
||||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
|
||||||
host=settings.LANGFUSE_HOST,
|
|
||||||
)
|
|
||||||
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_HOST)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse: failed to initialize: %s", exc)
|
|
||||||
_client = None
|
|
||||||
|
|
||||||
return _client
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
|
||||||
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
|
||||||
|
|
||||||
Returns ``(raw_template, prompt_obj_or_None)``.
|
|
||||||
|
|
||||||
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
|
||||||
on it directly; use :func:`compile_prompt` instead so the correct variable
|
|
||||||
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
|
||||||
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
|
||||||
unavailable / the fetch failed. Pass this to generation observations so
|
|
||||||
Langfuse links the generation to the exact prompt version in the UI.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
return fallback, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
|
||||||
# For text-type prompts .prompt holds the raw template string.
|
|
||||||
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
|
||||||
return raw, prompt
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
|
||||||
return fallback, None
|
|
||||||
|
|
||||||
|
|
||||||
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
|
||||||
"""Compile *template* with *variables*, choosing the right syntax.
|
|
||||||
|
|
||||||
* When *prompt_obj* is a real Langfuse prompt object, calls
|
|
||||||
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
|
||||||
substitution as defined in the Langfuse UI.
|
|
||||||
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
|
||||||
falls back to ``template.format(**variables)`` which handles the
|
|
||||||
``{variable}`` syntax used in the hardcoded fallback strings.
|
|
||||||
|
|
||||||
This keeps callers oblivious to which syntax is in use.
|
|
||||||
"""
|
|
||||||
if prompt_obj is not None:
|
|
||||||
try:
|
|
||||||
compiled = prompt_obj.compile(**variables)
|
|
||||||
# compile() returns a string for text prompts.
|
|
||||||
if isinstance(compiled, str):
|
|
||||||
return compiled
|
|
||||||
# Chat prompts return a list of dicts — join text parts.
|
|
||||||
if isinstance(compiled, list):
|
|
||||||
return "\n".join(
|
|
||||||
m.get("content", "") for m in compiled if isinstance(m, dict)
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
|
||||||
getattr(prompt_obj, "name", "?"),
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
return template.format(**variables)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_usage(response: Any) -> dict[str, int]:
|
|
||||||
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
|
||||||
meta = getattr(response, "usage_metadata", None)
|
|
||||||
if not meta:
|
|
||||||
return {}
|
|
||||||
return {
|
|
||||||
"input": int(meta.get("input_tokens", 0)),
|
|
||||||
"output": int(meta.get("output_tokens", 0)),
|
|
||||||
"total": int(meta.get("total_tokens", 0)),
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -18,7 +18,6 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -33,14 +32,6 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
# Some provider responses include a plain dict in the `usage` field where a
|
|
||||||
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
|
||||||
category=UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
|||||||
@@ -43,21 +43,15 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
class MemoryMiddleware:
|
||||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
"""Enrich agent context with memory and persist interactions after."""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
self,
|
"""Build memory context dict to inject into the agent before LLM call.
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
core_memory — {key: plaintext_value, ...}
|
core_memory — {key: plaintext_value, ...}
|
||||||
@@ -71,21 +65,9 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
len(core),
|
|
||||||
len(associative),
|
|
||||||
len(episodic),
|
|
||||||
len(proactive),
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -99,7 +81,6 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
trace_id: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -122,19 +103,11 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
session_id,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -160,176 +133,10 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
key,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
|
||||||
"""Return core memory as editable blocks (label/value)."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore)
|
|
||||||
.where(MemoryCore.user_id == user_id)
|
|
||||||
.order_by(MemoryCore.key.asc())
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
out: list[dict[str, str]] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append({"label": row.key, "value": plaintext})
|
|
||||||
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
|
||||||
"""Return a single core memory block value by label."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(
|
|
||||||
MemoryCore.user_id == user_id,
|
|
||||||
MemoryCore.key == label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
|
||||||
return None
|
|
||||||
value = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
|
||||||
"""Delete a core memory block by label. Returns True if deleted."""
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(
|
|
||||||
MemoryCore.user_id == user_id,
|
|
||||||
MemoryCore.key == label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
|
||||||
return False
|
|
||||||
|
|
||||||
await self._db.delete(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
|
||||||
"""Append content to a core block, creating it if missing."""
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None:
|
|
||||||
await self.update_core(user_id, label, content)
|
|
||||||
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
|
||||||
return
|
|
||||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
|
||||||
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
|
||||||
|
|
||||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
|
||||||
"""Replace one exact string inside a core block. Returns False if not found."""
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None or old not in current:
|
|
||||||
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
|
||||||
return False
|
|
||||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
|
||||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
|
||||||
"""Insert a long-term archival memory entry."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
encrypted = _encrypt(fernet, content)
|
|
||||||
row = MemoryAssociative(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
content_encrypted=encrypted,
|
|
||||||
embedding=None,
|
|
||||||
entity_type=source,
|
|
||||||
entity_id=None,
|
|
||||||
)
|
|
||||||
self._db.add(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryAssociative)
|
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
|
||||||
.order_by(MemoryAssociative.updated_at.desc())
|
|
||||||
.limit(100)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
"""Search recall memory (episodic summaries) by keyword."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryEpisodic)
|
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
|
||||||
.limit(100)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -341,16 +148,6 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
|
||||||
"""Load lightweight user debug fields for trace logs."""
|
|
||||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None:
|
|
||||||
return {"tier": None}
|
|
||||||
return {
|
|
||||||
"tier": user.tier,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -386,17 +183,10 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
fernet: Fernet,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
|
||||||
if session_id:
|
|
||||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
query
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,47 +1,141 @@
|
|||||||
"""Output formatter for deep-agent stream events."""
|
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||||
|
|
||||||
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||||
|
* ``("token", str)`` — supervisor text token
|
||||||
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||||
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
|
HomeFormatter:
|
||||||
|
* Streams text tokens as-is → emits ``WsStreamText``
|
||||||
|
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||||
|
for the frontend to parse and render as interactive components)
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
|
FloatingFormatter:
|
||||||
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
|
* Streams text tokens → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Map sub-agent tool name → floating domain / entity type
|
||||||
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
|
"timeline_agent": "timelines",
|
||||||
|
"note_agent": "notes",
|
||||||
|
"project_agent": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
class StreamFormatter:
|
class HomeFormatter:
|
||||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
|
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||||
|
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||||
|
is responsible for parsing those and rendering interactive components.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
started = False
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "floating_domain":
|
if event_type == "token":
|
||||||
if isinstance(data, dict):
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||||
|
|
||||||
|
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||||
|
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||||
|
``WsStreamText``. No block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "tool_end" and not domain_sent:
|
||||||
|
# Sniff domain from the first sub-agent that completes
|
||||||
|
name = data.get("name", "")
|
||||||
|
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=data,
|
domain=domain, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
if event_type != "token":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not started:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
started = True
|
domain_sent = True
|
||||||
|
|
||||||
text = str(data or "")
|
elif event_type == "token":
|
||||||
if text:
|
if not domain_sent:
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
# First token arrived before any tool_end — default domain
|
||||||
|
yield WsFloatingDomain(
|
||||||
if not started:
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
domain_sent = True
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
# If no events triggered domain_sent (edge case), still emit structure
|
||||||
|
if not domain_sent:
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,104 +0,0 @@
|
|||||||
"""Preprocessor registry: detect content type and dispatch to handlers.
|
|
||||||
|
|
||||||
Public API
|
|
||||||
----------
|
|
||||||
detect_content_type(filename, raw_content) -> str
|
|
||||||
Heuristic detection based on file extension and content patterns.
|
|
||||||
|
|
||||||
preprocess(content_type, raw_content) -> PreprocessResult
|
|
||||||
Dispatch to the appropriate handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from app.core.preprocessors.base import PreprocessResult
|
|
||||||
|
|
||||||
# ── Heuristics ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
# Patterns that strongly suggest an email HTML file
|
|
||||||
_EMAIL_SIGNALS = re.compile(
|
|
||||||
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Patterns that suggest a generic HTML page (not an email)
|
|
||||||
_GENERIC_HTML_SIGNALS = re.compile(
|
|
||||||
r"<(nav|main|header|footer|article|section)\b",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_content_type(filename: str, raw_content: str) -> str:
|
|
||||||
"""Return a content-type string for the given file.
|
|
||||||
|
|
||||||
Supported types: ``"email_html"``, ``"generic_html"``,
|
|
||||||
``"plain_text"``, ``"unknown"``.
|
|
||||||
"""
|
|
||||||
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
|
||||||
|
|
||||||
if ext == "txt":
|
|
||||||
return "plain_text"
|
|
||||||
|
|
||||||
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
|
||||||
# Prefer email detection over generic HTML
|
|
||||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
|
||||||
return "email_html"
|
|
||||||
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
|
||||||
return "generic_html"
|
|
||||||
# .html without clear signals — check for any email header
|
|
||||||
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
|
||||||
return "email_html"
|
|
||||||
return "generic_html"
|
|
||||||
|
|
||||||
# Plain text files with email headers
|
|
||||||
if ext in ("", "txt") or not ext:
|
|
||||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
|
||||||
return "email_html"
|
|
||||||
|
|
||||||
# Detect binary content
|
|
||||||
try:
|
|
||||||
raw_content.encode("utf-8")
|
|
||||||
except (UnicodeEncodeError, AttributeError):
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
# Non-text bytes heuristic: high ratio of non-printable chars
|
|
||||||
sample = raw_content[:512]
|
|
||||||
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
|
||||||
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Generic fallback handler ──────────────────────────────────────────
|
|
||||||
|
|
||||||
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
|
||||||
"""Strip HTML tags if present, return text as-is."""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
|
||||||
except ImportError:
|
|
||||||
# No BeautifulSoup — strip tags with a simple regex
|
|
||||||
text = re.sub(r"<[^>]+>", "", raw_content)
|
|
||||||
|
|
||||||
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
|
||||||
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
|
||||||
|
|
||||||
|
|
||||||
# ── Dispatch ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
|
||||||
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
|
||||||
|
|
||||||
Falls back to the generic handler for unknown types.
|
|
||||||
"""
|
|
||||||
if content_type == "email_html":
|
|
||||||
from app.core.preprocessors.email_html import preprocess_email_html
|
|
||||||
return preprocess_email_html(raw_content)
|
|
||||||
|
|
||||||
return _preprocess_generic(raw_content, content_type)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
"""Base types for the preprocessor system."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PreprocessResult:
|
|
||||||
"""Output of a preprocessor handler.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
content_type:
|
|
||||||
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
|
||||||
clean_text:
|
|
||||||
Human-readable text stripped of markup/binary noise.
|
|
||||||
metadata:
|
|
||||||
Dict of extracted metadata (keys vary by handler).
|
|
||||||
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
content_type: str
|
|
||||||
clean_text: str
|
|
||||||
metadata: dict = field(default_factory=dict)
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
"""Preprocessor for email HTML files.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- HTML stripping via BeautifulSoup
|
|
||||||
- Metadata extraction (Subject, From, To, Date)
|
|
||||||
- Thread splitting — isolates the latest reply
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from app.core.preprocessors.base import PreprocessResult
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ── Thread split markers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
# Matches patterns like:
|
|
||||||
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
|
||||||
# "-----Original Message-----"
|
|
||||||
# "> " (plain-text quote prefix)
|
|
||||||
_THREAD_PATTERNS = [
|
|
||||||
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
|
||||||
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
|
||||||
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
|
||||||
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
|
||||||
]
|
|
||||||
|
|
||||||
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
|
||||||
|
|
||||||
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
|
||||||
"subject": [
|
|
||||||
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
|
||||||
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"from": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"to": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"date": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
|
||||||
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_metadata(raw_html: str, text: str) -> dict:
|
|
||||||
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
|
||||||
metadata: dict[str, str] = {}
|
|
||||||
for field, patterns in _META_PATTERNS.items():
|
|
||||||
for pat in patterns:
|
|
||||||
m = pat.search(raw_html) or pat.search(text)
|
|
||||||
if m:
|
|
||||||
metadata[field] = m.group(1).strip()
|
|
||||||
break
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
|
|
||||||
def _split_thread(text: str) -> str:
|
|
||||||
"""Return only the latest message in a threaded email."""
|
|
||||||
earliest_pos: int | None = None
|
|
||||||
for pat in _THREAD_PATTERNS:
|
|
||||||
m = pat.search(text)
|
|
||||||
if m and (earliest_pos is None or m.start() < earliest_pos):
|
|
||||||
earliest_pos = m.start()
|
|
||||||
|
|
||||||
if earliest_pos is not None and earliest_pos > 0:
|
|
||||||
return text[:earliest_pos].strip()
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
|
||||||
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup # lazy import — optional dep
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"beautifulsoup4 is required for email_html preprocessing. "
|
|
||||||
"Install it with: pip install beautifulsoup4"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# Parse with lxml if available, fall back to html.parser
|
|
||||||
try:
|
|
||||||
soup = BeautifulSoup(raw_content, "lxml")
|
|
||||||
except Exception:
|
|
||||||
soup = BeautifulSoup(raw_content, "html.parser")
|
|
||||||
|
|
||||||
# Remove noise tags
|
|
||||||
for tag in soup(["style", "script", "head", "noscript"]):
|
|
||||||
tag.decompose()
|
|
||||||
|
|
||||||
clean_text = soup.get_text(separator="\n")
|
|
||||||
# Collapse excessive blank lines
|
|
||||||
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
|
||||||
|
|
||||||
metadata = _extract_metadata(raw_content, clean_text)
|
|
||||||
latest_message = _split_thread(clean_text)
|
|
||||||
|
|
||||||
return PreprocessResult(
|
|
||||||
content_type="email_html",
|
|
||||||
clean_text=latest_message,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
@@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional collector that captures raw execute_on_client results.
|
# Optional collector that captures raw execute_on_client results.
|
||||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
"_tool_result_collector", default=None
|
"_tool_result_collector", default=None
|
||||||
)
|
)
|
||||||
@@ -81,12 +84,17 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
result = await callback(payload)
|
result = await callback(payload)
|
||||||
|
if result is None:
|
||||||
|
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
else:
|
||||||
|
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
collector = _tool_result_collector.get(None)
|
collector = _tool_result_collector.get(None)
|
||||||
if collector is not None:
|
if collector is not None and action in ("insert", "update", "delete"):
|
||||||
collector.append({
|
collector.append({
|
||||||
"action": action,
|
"action": action,
|
||||||
"table": table,
|
"table": table,
|
||||||
"data": result,
|
"data": data or {},
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: ensure agent tool modules are loaded.
|
# Startup: initialise DB connection pool
|
||||||
import app.agents # noqa: F401
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
@@ -50,7 +48,7 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
@@ -60,6 +58,7 @@ def create_app() -> FastAPI:
|
|||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
app.include_router(agent_setup.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
|||||||
@@ -296,7 +296,6 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
|||||||
167
app/schemas.py
167
app/schemas.py
@@ -142,6 +142,9 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
|
agent_run = "agent_run"
|
||||||
|
agent_data = "agent_data"
|
||||||
|
agent_complete = "agent_complete"
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
@@ -153,10 +156,6 @@ class WsFrameType(str, Enum):
|
|||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
# ── v4 journey frame types ────────────────────────────────────────
|
|
||||||
journey_start = "journey_start"
|
|
||||||
journey_message = "journey_message"
|
|
||||||
journey_reply = "journey_reply"
|
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -209,6 +208,31 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentRun(BaseModel):
|
||||||
|
"""Server → Client: trigger an agent run on the connected device."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
||||||
|
run_id: str
|
||||||
|
agent_id: str
|
||||||
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentData(BaseModel):
|
||||||
|
"""Client → Server: files read by the local agent."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
||||||
|
run_id: str
|
||||||
|
files: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentComplete(BaseModel):
|
||||||
|
"""Client → Server: Electron signals it has finished reading files."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
||||||
|
run_id: str
|
||||||
|
files_read: int
|
||||||
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -255,14 +279,7 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
|
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
class WsDomain(BaseModel):
|
|
||||||
"""Structured floating domain payload for UI routing decisions."""
|
|
||||||
|
|
||||||
type: Literal["task", "timeline", "project", "node"]
|
|
||||||
id: str | None = None
|
|
||||||
section: Literal["task", "timeline", "note"] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -270,28 +287,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: WsDomain
|
domain: Literal["tasks", "timelines", "notes", "projects"]
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Config V2 ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeConfig(BaseModel):
|
|
||||||
"""Per-type extraction config produced by the journey chatbot."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
label: str = ""
|
|
||||||
detection_hint: str = ""
|
|
||||||
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
|
||||||
extraction_prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
|
||||||
"""Structured agent configuration (replaces freeform prompt_template)."""
|
|
||||||
|
|
||||||
content_types: list[ContentTypeConfig] = []
|
|
||||||
global_rules: list[str] = []
|
|
||||||
data_types: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -300,28 +296,84 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
config_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckRequest(BaseModel):
|
# ── Local Agent Config ────────────────────────────────────────────────
|
||||||
active_agents: int = Field(ge=0, default=0)
|
|
||||||
|
class LocalAgentConfigCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
device_id: str
|
||||||
|
directory_paths: list[str]
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
file_extensions: list[str]
|
||||||
|
schedule_cron: str
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckResponse(BaseModel):
|
class LocalAgentConfigUpdate(BaseModel):
|
||||||
allowed: bool
|
name: str | None = None
|
||||||
tier: BillingTier
|
device_id: str | None = None
|
||||||
active_agents: int
|
directory_paths: list[str] | None = None
|
||||||
limit: int
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
file_extensions: list[str] | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentTriggerRequest(BaseModel):
|
class LocalAgentConfigResponse(BaseModel):
|
||||||
directory: str = Field(min_length=1)
|
id: str
|
||||||
device_id: str = Field(default="")
|
name: str
|
||||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
device_id: str
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
directory_paths: list[str]
|
||||||
actions_by_type: dict[str, list[str]] | None = None
|
data_types: list[str]
|
||||||
batch_interval: str = Field(min_length=1)
|
prompt_template: str
|
||||||
custom_agent_prompt: str = Field(min_length=1)
|
file_extensions: list[str]
|
||||||
active_agents: int = Field(ge=0, default=0)
|
schedule_cron: str
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Agent Config ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CloudAgentConfigCreate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
oauth_token_encrypted: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigUpdate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
oauth_token_encrypted: str | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigResponse(BaseModel):
|
||||||
|
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -340,3 +392,18 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class JourneyStartRequest(BaseModel):
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
agent_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyMessageRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
done: bool
|
||||||
|
prompt_template: str | None = None
|
||||||
|
|||||||
@@ -1,941 +0,0 @@
|
|||||||
# Adiuva — Architettura Microservizi (MVP)
|
|
||||||
|
|
||||||
## Panoramica
|
|
||||||
|
|
||||||
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
|
||||||
|
|
||||||
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐
|
|
||||||
│ Cloudflare │
|
|
||||||
│ (DNS + CDN) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│ HTTPS / WSS
|
|
||||||
┌──────▼───────┐
|
|
||||||
│ Traefik │
|
|
||||||
│ API Gateway │
|
|
||||||
│ (routing, │
|
|
||||||
│ TLS, rate │
|
|
||||||
│ limiting) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│
|
|
||||||
┌──────────┬───────────┼───────────┐
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
|
||||||
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
|
||||||
│ Service │ │Service│ │ Service │ │Service │
|
|
||||||
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼──────────▼──────────▼───────────▼────┐
|
|
||||||
│ Infrastruttura │
|
|
||||||
│ PostgreSQL │ Redis │ Qdrant │
|
|
||||||
└─────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Suddivisione dei Servizi
|
|
||||||
|
|
||||||
### 1.1 Auth Service (`auth-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/auth/register` | POST |
|
|
||||||
| `/api/v1/auth/login` | POST |
|
|
||||||
| `/api/v1/auth/refresh` | POST |
|
|
||||||
| `/api/v1/auth/me` | GET / PUT |
|
|
||||||
|
|
||||||
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
|
||||||
|
|
||||||
**Modifica chiave — JWT con RS256**:
|
|
||||||
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
|
||||||
- L'Auth Service firma i JWT con la **chiave privata**.
|
|
||||||
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
|
||||||
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# auth-service/app/auth/jwt.py
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PRIVATE_KEY = ... # Da env/secret
|
|
||||||
PUBLIC_KEY = ... # Derivata o da env
|
|
||||||
|
|
||||||
def create_access_token(user_id: str, tier: str) -> str:
|
|
||||||
return jwt.encode(
|
|
||||||
{"sub": user_id, "tier": tier, "exp": ...},
|
|
||||||
PRIVATE_KEY,
|
|
||||||
algorithm="RS256",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py (usato da tutti gli altri servizi)
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
|
||||||
|
|
||||||
def verify_token(token: str) -> dict:
|
|
||||||
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
|
||||||
|
|
||||||
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
|
||||||
|
|
||||||
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
|
||||||
| `/api/v1/chat` | POST (REST fallback) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
|
||||||
|
|
||||||
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
|
||||||
|
|
||||||
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
|
||||||
|
|
||||||
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
|
||||||
|
|
||||||
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/agents/catalog` | GET |
|
|
||||||
| `/api/v1/agents/can-create` | POST |
|
|
||||||
| `/api/v1/agents/trigger` | POST |
|
|
||||||
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
|
||||||
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
|
||||||
|
|
||||||
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
|
||||||
│ Agent Service│ │ Redis │ │ Chat │
|
|
||||||
│ (batch run) │ │ │ │ Service │
|
|
||||||
│ │ │ │ │ (ha WS) │
|
|
||||||
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
|
||||||
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
|
||||||
│ from │ │ │ │ al │
|
|
||||||
│ device │ │ │ │ device│
|
|
||||||
│ │ │ │ │ via WS│
|
|
||||||
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
|
||||||
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
|
||||||
│ risultato │ │ │ │ reply │
|
|
||||||
└──────────────┘ └──────────────┘ └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
|
||||||
|
|
||||||
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.4 Billing Service (`billing-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/billing/checkout` | POST |
|
|
||||||
| `/api/v1/billing/webhook` | POST |
|
|
||||||
| `/api/v1/billing/subscription` | GET / DELETE |
|
|
||||||
|
|
||||||
**Database**: Tabelle `subscriptions` (schema `billing`).
|
|
||||||
|
|
||||||
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
|
||||||
|
|
||||||
**Scaling**: 1 replica sufficiente. Basso traffico.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.5 Servizi esclusi dall'MVP
|
|
||||||
|
|
||||||
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
|
||||||
|
|
||||||
| Servizio | Responsabilità | Note |
|
|
||||||
|---|---|---|
|
|
||||||
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
|
||||||
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Tier Check — Dove e Come
|
|
||||||
|
|
||||||
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
|
||||||
|
|
||||||
### Strategia: Tier nel JWT
|
|
||||||
|
|
||||||
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"sub": "user_123",
|
|
||||||
"tier": "pro",
|
|
||||||
"exp": 1742515200,
|
|
||||||
"iat": 1742511600
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio:
|
|
||||||
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
|
||||||
2. Legge `payload["tier"]` — **zero chiamate extra**
|
|
||||||
3. Applica le sue regole di enforcement localmente
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py — dependency FastAPI condivisa
|
|
||||||
from fastapi import Depends, HTTPException, Request
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ...
|
|
||||||
|
|
||||||
class CurrentUser:
|
|
||||||
def __init__(self, user_id: str, tier: str):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.tier = tier
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> CurrentUser:
|
|
||||||
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
|
||||||
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
|
||||||
|
|
||||||
def require_tier(*allowed_tiers: str):
|
|
||||||
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
|
||||||
async def check(user: CurrentUser = Depends(get_current_user)):
|
|
||||||
if user.tier not in allowed_tiers:
|
|
||||||
raise HTTPException(403, "Tier insufficient")
|
|
||||||
return user
|
|
||||||
return check
|
|
||||||
```
|
|
||||||
|
|
||||||
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
|
||||||
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
|
||||||
│ │ │ Service │ (Redis pub/sub) │ Service │
|
|
||||||
└──────────┘ └──────────┘ └────┬─────┘
|
|
||||||
│
|
|
||||||
UPDATE users
|
|
||||||
SET tier = 'power'
|
|
||||||
│
|
|
||||||
Al prossimo /refresh
|
|
||||||
il JWT conterrà tier='power'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
|
||||||
|
|
||||||
### Dove si applica in ciascun servizio
|
|
||||||
|
|
||||||
| Servizio | Enforcement |
|
|
||||||
|---|---|
|
|
||||||
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
|
||||||
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
|
||||||
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
|
||||||
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
|
||||||
|
|
||||||
### Rate-limit distribuito via Redis
|
|
||||||
|
|
||||||
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/middleware/rate_limit.py
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
class DistributedRateLimiter:
|
|
||||||
def __init__(self, redis: aioredis.Redis):
|
|
||||||
self._redis = redis
|
|
||||||
|
|
||||||
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
|
||||||
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
|
||||||
max_req = limits.get(tier, 20)
|
|
||||||
key = f"rate:{service}:{user_id}"
|
|
||||||
|
|
||||||
pipe = self._redis.pipeline()
|
|
||||||
pipe.incr(key)
|
|
||||||
pipe.expire(key, 60)
|
|
||||||
count, _ = await pipe.execute()
|
|
||||||
|
|
||||||
return count <= max_req
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
|
||||||
|
|
||||||
`DeviceConnectionManager` è un **singleton in-memory**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class DeviceConnectionManager:
|
|
||||||
def __init__(self):
|
|
||||||
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
|
||||||
```
|
|
||||||
|
|
||||||
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
|
||||||
|
|
||||||
### La soluzione: Redis Pub/Sub + Registry
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────────────────────────────────────────────────────┐
|
|
||||||
│ Redis │
|
|
||||||
│ │
|
|
||||||
│ Hash: ws:connections │
|
|
||||||
│ user_123 → instance_A │
|
|
||||||
│ user_456 → instance_B │
|
|
||||||
│ │
|
|
||||||
│ Pub/Sub channels: │
|
|
||||||
│ tool_call:{user_id} → tool call payloads │
|
|
||||||
│ tool_result:{call_id} → tool result payloads │
|
|
||||||
│ stream:{user_id} → text_chunk streaming │
|
|
||||||
└──────────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
|
||||||
┌───────────────────────┐ ┌───────────────────────┐
|
|
||||||
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
|
||||||
│ tool_call:user_123│ │ → user_123 è su A │
|
|
||||||
│ │ │ │
|
|
||||||
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
|
||||||
│ da Redis channel │ │ tool_call:user_123 │
|
|
||||||
│ │ │ {id, action, ...} │
|
|
||||||
│ 3. Invia al device │ │ │
|
|
||||||
│ via WS │ │ 4. SUBSCRIBE │
|
|
||||||
│ │ │ tool_result:{id} │
|
|
||||||
│ 4. Device risponde │ │ │
|
|
||||||
│ tool_result │──────────│► 5. Riceve risultato │
|
|
||||||
│ │ │ │
|
|
||||||
│ 5. PUBLISH │ │ │
|
|
||||||
│ tool_result:{id} │ │ │
|
|
||||||
└───────────────────────┘ └───────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### Implementazione: `RedisDeviceManager`
|
|
||||||
|
|
||||||
```python
|
|
||||||
# chat-service/app/core/device_manager.py
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from fastapi import WebSocket
|
|
||||||
|
|
||||||
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LocalConnection:
|
|
||||||
ws: WebSocket
|
|
||||||
device_id: str
|
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisDeviceManager:
|
|
||||||
"""Device manager backed by Redis for cross-instance communication."""
|
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://redis:6379"):
|
|
||||||
self._redis = aioredis.from_url(redis_url)
|
|
||||||
self._pubsub = self._redis.pubsub()
|
|
||||||
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
|
||||||
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""Avvia il listener Redis per tool_call in arrivo."""
|
|
||||||
asyncio.create_task(self._listen_tool_calls())
|
|
||||||
|
|
||||||
# ── Registrazione ──
|
|
||||||
|
|
||||||
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
|
||||||
# Registra localmente
|
|
||||||
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
|
||||||
# Registra in Redis quale istanza ha la connessione
|
|
||||||
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
|
||||||
# Sottoscrivi ai tool_call per questo utente
|
|
||||||
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
async def unregister(self, user_id: str):
|
|
||||||
conn = self._local.pop(user_id, None)
|
|
||||||
if conn:
|
|
||||||
for fut in conn.pending_calls.values():
|
|
||||||
if not fut.done():
|
|
||||||
fut.cancel()
|
|
||||||
await self._redis.hdel("ws:connections", user_id)
|
|
||||||
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
# ── Presenza ──
|
|
||||||
|
|
||||||
async def is_online(self, user_id: str) -> bool:
|
|
||||||
return await self._redis.hexists("ws:connections", user_id)
|
|
||||||
|
|
||||||
# ── Tool-call round-trip (cross-instance) ──
|
|
||||||
|
|
||||||
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Invia un tool_call al device dell'utente.
|
|
||||||
Funziona sia che la WS sia locale che su un'altra istanza.
|
|
||||||
"""
|
|
||||||
call_id = payload["id"]
|
|
||||||
|
|
||||||
# Caso 1: connessione locale → invio diretto
|
|
||||||
if user_id in self._local:
|
|
||||||
conn = self._local[user_id]
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut: asyncio.Future[dict] = loop.create_future()
|
|
||||||
conn.pending_calls[call_id] = fut
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
|
|
||||||
# Caso 2: connessione remota → Redis pub/sub
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut = loop.create_future()
|
|
||||||
self._remote_futures[call_id] = fut
|
|
||||||
|
|
||||||
# Sottoscrivi al canale di risposta
|
|
||||||
result_channel = f"tool_result:{call_id}"
|
|
||||||
await self._pubsub.subscribe(result_channel)
|
|
||||||
|
|
||||||
# Pubblica il tool_call
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_call:{user_id}",
|
|
||||||
json.dumps(payload),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
finally:
|
|
||||||
self._remote_futures.pop(call_id, None)
|
|
||||||
await self._pubsub.unsubscribe(result_channel)
|
|
||||||
|
|
||||||
# ── Risoluzione tool_result (da WS locale) ──
|
|
||||||
|
|
||||||
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
fut = conn.pending_calls.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(result)
|
|
||||||
|
|
||||||
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
|
||||||
"""Chiamato quando il device locale invia un tool_result."""
|
|
||||||
self.resolve_local(user_id, call_id, result)
|
|
||||||
# Pubblica anche su Redis per l'istanza remota che aspetta
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_result:{call_id}",
|
|
||||||
json.dumps(result),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Listener Redis ──
|
|
||||||
|
|
||||||
async def _listen_tool_calls(self):
|
|
||||||
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
|
||||||
async for message in self._pubsub.listen():
|
|
||||||
if message["type"] != "message":
|
|
||||||
continue
|
|
||||||
channel = message["channel"]
|
|
||||||
if isinstance(channel, bytes):
|
|
||||||
channel = channel.decode()
|
|
||||||
|
|
||||||
data = json.loads(message["data"])
|
|
||||||
|
|
||||||
if channel.startswith("tool_call:"):
|
|
||||||
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
|
||||||
user_id = channel.split(":", 1)[1]
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
|
||||||
|
|
||||||
elif channel.startswith("tool_result:"):
|
|
||||||
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
|
||||||
call_id = channel.split(":", 1)[1]
|
|
||||||
fut = self._remote_futures.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(data)
|
|
||||||
|
|
||||||
# ── Stream cross-instance ──
|
|
||||||
|
|
||||||
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
|
||||||
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
|
||||||
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Struttura Directory Proposta (MVP)
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── docker-compose.yml # Orchestrazione completa
|
|
||||||
├── docker-compose.dev.yml # Override per sviluppo locale
|
|
||||||
├── shared/ # Codice condiviso (montato come volume)
|
|
||||||
│ ├── auth.py # JWT verification (chiave pubblica)
|
|
||||||
│ ├── schemas.py # Pydantic schemas condivisi
|
|
||||||
│ ├── middleware/
|
|
||||||
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
|
||||||
│ │ └── sanitizer.py
|
|
||||||
│ └── models/
|
|
||||||
│ └── base.py # SQLAlchemy base condivisa
|
|
||||||
│
|
|
||||||
├── auth-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # users, refresh_tokens
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── auth.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── jwt_service.py # RS256 signing
|
|
||||||
│ └── user_service.py
|
|
||||||
│
|
|
||||||
├── chat-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # memory_*
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── device_ws.py # WS connection owner
|
|
||||||
│ │ └── chat.py # REST fallback
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── device_manager.py # RedisDeviceManager
|
|
||||||
│ │ ├── deep_agent.py # Home + floating chat
|
|
||||||
│ │ ├── memory_middleware.py
|
|
||||||
│ │ ├── ws_context.py
|
|
||||||
│ │ ├── output_formatter.py
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/ # Tool definitions (used by deep_agent)
|
|
||||||
│ ├── task_agent.py
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ └── timeline_agent.py
|
|
||||||
│
|
|
||||||
├── agent-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── agents.py # catalog, can-create, trigger
|
|
||||||
│ │ └── agent_setup.py # journey start/message
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── agent_runner.py # Batch classify → process
|
|
||||||
│ │ ├── agent_registry.py
|
|
||||||
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/
|
|
||||||
│ ├── task_agent.py # Tool definitions (batch context)
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ ├── timeline_agent.py
|
|
||||||
│ └── filesystem_agent.py
|
|
||||||
│
|
|
||||||
├── billing-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # subscriptions
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── billing.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── stripe_service.py
|
|
||||||
│ └── tier_manager.py
|
|
||||||
│
|
|
||||||
└── infra/
|
|
||||||
├── traefik/
|
|
||||||
│ └── traefik.yml
|
|
||||||
├── keys/
|
|
||||||
│ ├── jwt_private.pem # Solo auth-service
|
|
||||||
│ └── jwt_public.pem # Tutti i servizi
|
|
||||||
└── alembic/ # Migrazioni condivise o per-servizio
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. Docker Compose — Configurazione MVP
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# docker-compose.yml
|
|
||||||
|
|
||||||
services:
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# API Gateway
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
traefik:
|
|
||||||
image: traefik:v3.2
|
|
||||||
command:
|
|
||||||
- "--api.insecure=true"
|
|
||||||
- "--providers.docker=true"
|
|
||||||
- "--providers.docker.exposedbydefault=false"
|
|
||||||
- "--entrypoints.web.address=:80"
|
|
||||||
- "--entrypoints.websecure.address=:443"
|
|
||||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
|
||||||
ports:
|
|
||||||
- "80:80"
|
|
||||||
- "443:443"
|
|
||||||
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
|
||||||
volumes:
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
|
||||||
- ./infra/certs:/certs:ro
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Auth Service (2 repliche)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
auth-service:
|
|
||||||
build: ./auth-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
|
||||||
SERVICE_NAME: auth
|
|
||||||
secrets:
|
|
||||||
- jwt_private_key
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
|
||||||
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Chat Service — Real-time WS + Chat (scalabile)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
chat-service:
|
|
||||||
build: ./chat-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: chat
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
# REST chat endpoint
|
|
||||||
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
|
||||||
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
|
||||||
# WebSocket route con sticky session
|
|
||||||
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
|
||||||
- "traefik.http.routers.ws.service=chat-ws"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Agent Service — Batch processing (scalabile indipendentemente)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
agent-service:
|
|
||||||
build: ./agent-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: agent
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
|
||||||
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Billing Service (1 replica)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
billing-service:
|
|
||||||
build: ./billing-service
|
|
||||||
deploy:
|
|
||||||
replicas: 1
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: billing
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
|
||||||
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Infrastruttura
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
db:
|
|
||||||
image: pgvector/pgvector:pg16
|
|
||||||
environment:
|
|
||||||
POSTGRES_USER: postgres
|
|
||||||
POSTGRES_PASSWORD: postgres
|
|
||||||
POSTGRES_DB: adiuva
|
|
||||||
volumes:
|
|
||||||
- postgres_data:/var/lib/postgresql/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
redis:
|
|
||||||
image: redis:7-alpine
|
|
||||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
|
||||||
volumes:
|
|
||||||
- redis_data:/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 3s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
qdrant:
|
|
||||||
image: qdrant/qdrant:latest
|
|
||||||
volumes:
|
|
||||||
- qdrant_data:/qdrant/storage
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
secrets:
|
|
||||||
jwt_private_key:
|
|
||||||
file: ./infra/keys/jwt_private.pem
|
|
||||||
jwt_public_key:
|
|
||||||
file: ./infra/keys/jwt_public.pem
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
postgres_data:
|
|
||||||
redis_data:
|
|
||||||
qdrant_data:
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. Configurazione Cloudflare + VPS
|
|
||||||
|
|
||||||
### 6.1 DNS
|
|
||||||
|
|
||||||
```
|
|
||||||
api.tuodominio.com → A record → IP del VPS
|
|
||||||
→ Proxy: ON (orange cloud)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.2 Cloudflare Settings
|
|
||||||
|
|
||||||
| Setting | Valore | Motivo |
|
|
||||||
|---------|--------|--------|
|
|
||||||
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
|
||||||
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
|
||||||
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
|
||||||
| Under Attack Mode | Off (attivare se necessario) | |
|
|
||||||
|
|
||||||
### 6.3 TLS sul VPS
|
|
||||||
|
|
||||||
Due opzioni:
|
|
||||||
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
|
||||||
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# traefik.yml — con Cloudflare Origin Certificate
|
|
||||||
entryPoints:
|
|
||||||
websecure:
|
|
||||||
address: ":443"
|
|
||||||
|
|
||||||
tls:
|
|
||||||
certificates:
|
|
||||||
- certFile: /certs/origin.pem
|
|
||||||
keyFile: /certs/origin-key.pem
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.4 Rete VPS
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
|
||||||
# https://www.cloudflare.com/ips/
|
|
||||||
ufw default deny incoming
|
|
||||||
ufw allow from 173.245.48.0/20 to any port 443
|
|
||||||
ufw allow from 103.21.244.0/22 to any port 443
|
|
||||||
# ... (tutti gli IP range di Cloudflare)
|
|
||||||
ufw allow ssh
|
|
||||||
ufw enable
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Comunicazione Inter-Servizio
|
|
||||||
|
|
||||||
### 7.1 Redis Pub/Sub — Event Bus
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
|
||||||
│ Billing │ ────────────────────────► │ Auth │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
└──────────┘ └──────────┘
|
|
||||||
|
|
||||||
┌──────────┐ tool_call:user_123 ┌──────────┐
|
|
||||||
│ Agent │ ────────────────────────► │ Chat │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
|
||||||
└──────────┘ tool_result:{call_id} └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7.2 Health Checks e Service Discovery
|
|
||||||
|
|
||||||
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
|
||||||
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
|
||||||
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
|
||||||
- **PostgreSQL** (dati persistenti condivisi)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Piano di Migrazione Incrementale (MVP)
|
|
||||||
|
|
||||||
### Fase 1 — Preparazione (nel monolite attuale)
|
|
||||||
1. Aggiungere Redis al `docker-compose.yml` attuale
|
|
||||||
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
|
||||||
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
|
||||||
4. Estrarre `shared/` con auth verification, schemas, middleware
|
|
||||||
|
|
||||||
### Fase 2 — Auth Service (primo split)
|
|
||||||
1. Estrarre `auth.py` routes + models in `auth-service/`
|
|
||||||
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
|
||||||
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
|
||||||
4. Il monolite continua a servire tutto il resto
|
|
||||||
|
|
||||||
### Fase 3 — Billing Service
|
|
||||||
1. Estrarre billing routes, Stripe service, tier manager
|
|
||||||
2. Configurare Redis pub/sub per `tier_changed` events
|
|
||||||
3. Routare via Traefik
|
|
||||||
|
|
||||||
### Fase 4 — Split Chat + Agent (il più delicato)
|
|
||||||
1. Il monolite residuo contiene WS + chat + agents
|
|
||||||
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
|
||||||
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
|
||||||
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
|
||||||
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
|
||||||
|
|
||||||
### Fase 5 — Scaling test
|
|
||||||
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
|
||||||
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
|
||||||
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 9. Monitoraggio e Logging
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Aggiungere al docker-compose.yml
|
|
||||||
|
|
||||||
prometheus:
|
|
||||||
image: prom/prometheus:latest
|
|
||||||
volumes:
|
|
||||||
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
grafana:
|
|
||||||
image: grafana/grafana:latest
|
|
||||||
ports:
|
|
||||||
- "3000:3000"
|
|
||||||
volumes:
|
|
||||||
- grafana_data:/var/lib/grafana
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
loki:
|
|
||||||
image: grafana/loki:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 10. Sizing VPS Minimo Consigliato (MVP)
|
|
||||||
|
|
||||||
| Componente | CPU | RAM | Note |
|
|
||||||
|---|---|---|---|
|
|
||||||
| Traefik | 0.25 | 128MB | |
|
|
||||||
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
|
||||||
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
|
||||||
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
|
||||||
| Billing Service | 0.25 | 128MB | |
|
|
||||||
| PostgreSQL | 1.0 | 1GB | |
|
|
||||||
| Redis | 0.25 | 256MB | |
|
|
||||||
| Qdrant | 0.5 | 512MB | |
|
|
||||||
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
|
||||||
|
|
||||||
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Riepilogo Architettura MVP
|
|
||||||
|
|
||||||
| Servizio | Repliche | Proprietario di |
|
|
||||||
|---|---|---|
|
|
||||||
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
|
||||||
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
|
||||||
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
|
||||||
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
|
||||||
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
|
||||||
|
|
||||||
| Decisione | Scelta | Motivazione |
|
|
||||||
|---|---|---|
|
|
||||||
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
|
||||||
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
|
||||||
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
|
||||||
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
|
||||||
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
|
||||||
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
|
||||||
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
|
||||||
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
|
||||||
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
|
||||||
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
|
||||||
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
|
||||||
@@ -4,6 +4,8 @@ gunicorn>=22.0.0
|
|||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
langchain-litellm>=0.1.0
|
langchain-litellm>=0.1.0
|
||||||
|
langgraph>=0.3.0
|
||||||
|
deepagents>=0.4.10
|
||||||
litellm>=1.50.0
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
@@ -32,8 +34,4 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
langfuse>=2.0.0
|
|
||||||
beautifulsoup4>=4.12.0
|
|
||||||
lxml>=5.0.0
|
|
||||||
PyYAML>=6.0.0
|
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
@@ -6,21 +6,26 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
from moto import mock_aws
|
||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import Subscription, User
|
from app.models import Plugin, Subscription, User
|
||||||
|
|
||||||
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
@@ -104,6 +109,79 @@ def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # n
|
|||||||
app.dependency_overrides.pop(get_session, None)
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed data helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
Plugin(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and timeline updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author_name="Third Party",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
||||||
|
"""Insert the 3 default approved plugins and return them."""
|
||||||
|
plugins = []
|
||||||
|
for template in _SEED_PLUGINS:
|
||||||
|
p = Plugin(
|
||||||
|
id=template.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
version=template.version,
|
||||||
|
author_name=template.author_name,
|
||||||
|
category=template.category,
|
||||||
|
price_cents=template.price_cents,
|
||||||
|
permissions=template.permissions,
|
||||||
|
status=template.status,
|
||||||
|
s3_package_key=template.s3_package_key,
|
||||||
|
install_count=template.install_count,
|
||||||
|
avg_rating=template.avg_rating,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
plugins.append(p)
|
||||||
|
await db_session.commit()
|
||||||
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
# ── JWT helpers ──────────────────────────────────────────────────────
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -134,21 +212,24 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
# ── CLI options ───────────────────────────────────────────────────────
|
# ── S3 mock fixture ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
S3_TEST_BUCKET = "test-bucket"
|
||||||
parser.addoption(
|
S3_TEST_REGION = "us-east-1"
|
||||||
"--preprocess-dir",
|
|
||||||
default=None,
|
|
||||||
help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)",
|
@pytest.fixture
|
||||||
)
|
def s3_bucket():
|
||||||
parser.addoption(
|
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
||||||
"--runner-dir",
|
with mock_aws():
|
||||||
default=None,
|
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||||
help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)",
|
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||||
)
|
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
||||||
parser.addoption(
|
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
||||||
"--journey-dir",
|
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
||||||
default=None,
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)",
|
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
||||||
)
|
mock_settings.S3_REGION = S3_TEST_REGION
|
||||||
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
|
yield S3_TEST_BUCKET
|
||||||
|
|||||||
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
@@ -1,86 +0,0 @@
|
|||||||
# Agent Runner V2 — eval test cases (Step 2, requires real LLM)
|
|
||||||
#
|
|
||||||
# Each case drives one parametrized `test_eval_runner` invocation.
|
|
||||||
#
|
|
||||||
# Keys
|
|
||||||
# ----
|
|
||||||
# id: str unique identifier shown in pytest output
|
|
||||||
# description: str human-readable label
|
|
||||||
# file: str filename inside data/
|
|
||||||
# file_path: str path reported to the executor (affects project-matching via filename)
|
|
||||||
# projects: [alpha|beta] symbolic project names resolved by the test helper
|
|
||||||
#
|
|
||||||
# Optional pre-existing records (dedup tests)
|
|
||||||
# existing_tasks: list of {id, title, status, priority}
|
|
||||||
# existing_notes: list of {id, title, content}
|
|
||||||
# existing_timelines: list of {id, title, date}
|
|
||||||
#
|
|
||||||
# Assertions (one or more)
|
|
||||||
# expect_insert: <table> at least 1 insert row in this table (tasks|notes|timelines)
|
|
||||||
# expect_no_insert: true zero inserts in any table
|
|
||||||
# expect_project_id: <id> any insert must carry this projectId
|
|
||||||
# expect_dedup: true task inserts == 0 OR task updates >= 1 (dedup check)
|
|
||||||
#
|
|
||||||
# Langfuse
|
|
||||||
# score_name: str observation score name
|
|
||||||
|
|
||||||
- id: "2.1"
|
|
||||||
description: "Action email → create_task"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_action.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: tasks
|
|
||||||
score_name: runner.email_to_task
|
|
||||||
|
|
||||||
- id: "2.2"
|
|
||||||
description: "Informational email → create_note"
|
|
||||||
file: email_info.html
|
|
||||||
file_path: /emails/ProjectAlpha_info.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: notes
|
|
||||||
score_name: runner.email_to_note
|
|
||||||
|
|
||||||
- id: "2.3"
|
|
||||||
description: "Email with meeting date → create_timeline"
|
|
||||||
file: email_date.html
|
|
||||||
file_path: /emails/ProjectAlpha_kickoff.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: timelines
|
|
||||||
score_name: runner.email_to_timeline
|
|
||||||
|
|
||||||
- id: "2.4"
|
|
||||||
description: "Filename contains project name → correct project assigned"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_report.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_project_id: proj-alpha
|
|
||||||
score_name: runner.project_filename
|
|
||||||
|
|
||||||
- id: "2.5"
|
|
||||||
description: "Email body mentions project → correct project assigned"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/email_001.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_project_id: proj-alpha
|
|
||||||
score_name: runner.project_content
|
|
||||||
|
|
||||||
- id: "2.6"
|
|
||||||
description: "Newsletter + global rule no-project → no creates"
|
|
||||||
file: email_no_project.html
|
|
||||||
file_path: /emails/newsletter.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_no_insert: true
|
|
||||||
score_name: runner.no_project
|
|
||||||
|
|
||||||
- id: "2.7"
|
|
||||||
description: "Existing task with same title → dedup (update not create)"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_followup.html
|
|
||||||
projects: [alpha]
|
|
||||||
existing_tasks:
|
|
||||||
- id: task-existing
|
|
||||||
title: Fix the login bug
|
|
||||||
status: todo
|
|
||||||
priority: medium
|
|
||||||
expect_dedup: true
|
|
||||||
score_name: runner.dedup
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> boss@company.com</p>
|
|
||||||
<p><b>To:</b> dev@company.com</p>
|
|
||||||
<p><b>Subject:</b> Fix the login bug</p>
|
|
||||||
<p><b>Date:</b> 2026-04-07</p>
|
|
||||||
<p>Hi,<br>Please fix the login bug in Project Alpha by Friday. High priority!</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> pm@company.com</p>
|
|
||||||
<p><b>Subject:</b> Project Alpha kick-off meeting</p>
|
|
||||||
<p>The kick-off meeting for Project Alpha is scheduled for 2026-04-15 at 10:00.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> pm@company.com</p>
|
|
||||||
<p><b>To:</b> team@company.com</p>
|
|
||||||
<p><b>Subject:</b> FYI: New policy for Project Alpha</p>
|
|
||||||
<p>Just a heads-up that starting next week all code reviews must be done
|
|
||||||
within 24 hours for Project Alpha. No action needed from you now.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> newsletter@ads.com</p>
|
|
||||||
<p><b>Subject:</b> Weekly newsletter</p>
|
|
||||||
<p>Check out our latest deals on electronics!</p>
|
|
||||||
</body></html>
|
|
||||||
87
tests/fixtures/journey_v2/cases.yaml
vendored
87
tests/fixtures/journey_v2/cases.yaml
vendored
@@ -1,87 +0,0 @@
|
|||||||
# Journey V2 eval test cases — Step 4
|
|
||||||
#
|
|
||||||
# Each case simulates a complete journey session:
|
|
||||||
# 1. handle_journey_start is called with directory + data_types
|
|
||||||
# 2. handle_journey_message is called for each entry in user_messages
|
|
||||||
# 3. Assertions are evaluated on the final reply
|
|
||||||
#
|
|
||||||
# directory_files: list of {path, content_file} — content_file is relative to data/
|
|
||||||
#
|
|
||||||
# Assertion keys:
|
|
||||||
# expect_question: true → first reply must contain "?"
|
|
||||||
# expect_done: true → final reply must have done=True
|
|
||||||
# expect_valid_config: true → agent_config must be parseable as AgentConfig with content_types > 0
|
|
||||||
# expect_content_type_id: <str> → AgentConfig.content_types must contain an entry with this id
|
|
||||||
# expect_extraction_contains: <str> → first content_type extraction_prompt must contain this word
|
|
||||||
# expect_global_rules: true → AgentConfig.global_rules must be non-empty
|
|
||||||
|
|
||||||
- id: "4.1"
|
|
||||||
description: "Journey start explores directory, first reply contains a question"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["tasks", "notes", "timelines"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/outlook_export_2024.html"
|
|
||||||
content_file: "email_action.html"
|
|
||||||
user_messages: []
|
|
||||||
score_name: "journey.start"
|
|
||||||
expect_question: true
|
|
||||||
|
|
||||||
- id: "4.2"
|
|
||||||
description: "Full 3-turn conversation produces a valid AgentConfig JSON"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["tasks", "notes", "timelines"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/email_backup.html"
|
|
||||||
content_file: "email_action.html"
|
|
||||||
user_messages:
|
|
||||||
- "These are email exports from Outlook in HTML format"
|
|
||||||
- "Create tasks for emails with direct action requests, notes for informational emails"
|
|
||||||
- "Yes, that looks correct. No other rules."
|
|
||||||
score_name: "journey.valid_json"
|
|
||||||
expect_done: true
|
|
||||||
expect_valid_config: true
|
|
||||||
|
|
||||||
- id: "4.3"
|
|
||||||
description: "Journey detects email_html content type from directory exploration"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["tasks", "notes"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/message.html"
|
|
||||||
content_file: "email_action.html"
|
|
||||||
user_messages:
|
|
||||||
- "HTML email backups from my mail client, exported from Outlook"
|
|
||||||
- "Create tasks from emails that contain assignments or direct action items"
|
|
||||||
- "Correct, no other rules needed"
|
|
||||||
score_name: "journey.detect_email"
|
|
||||||
expect_done: true
|
|
||||||
expect_content_type_id: "email_html"
|
|
||||||
|
|
||||||
- id: "4.4"
|
|
||||||
description: "Custom user rule (only notes, no tasks) reflected in extraction_prompt"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["notes"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/email.html"
|
|
||||||
content_file: "email_info.html"
|
|
||||||
user_messages:
|
|
||||||
- "HTML emails from my work inbox"
|
|
||||||
- "Create only notes from all emails — I do not want tasks or timelines to be created"
|
|
||||||
- "Yes, exactly"
|
|
||||||
score_name: "journey.custom_rules"
|
|
||||||
expect_done: true
|
|
||||||
expect_extraction_contains: "note"
|
|
||||||
|
|
||||||
- id: "4.5"
|
|
||||||
description: "Global rule (no project = no entity) appears in AgentConfig.global_rules"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["tasks", "notes"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/email.html"
|
|
||||||
content_file: "email_action.html"
|
|
||||||
user_messages:
|
|
||||||
- "Email backups from Outlook"
|
|
||||||
- "Create tasks from action request emails, notes from informational emails"
|
|
||||||
- "If the email cannot be matched to any project, do not create any entity at all"
|
|
||||||
score_name: "journey.global_rules"
|
|
||||||
expect_done: true
|
|
||||||
expect_global_rules: true
|
|
||||||
23
tests/fixtures/journey_v2/data/email_action.html
vendored
23
tests/fixtures/journey_v2/data/email_action.html
vendored
@@ -1,23 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Email: Fix the login bug</title>
|
|
||||||
<style>body { font-family: Arial; } .header { color: #666; }</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> boss@company.com</p>
|
|
||||||
<p><strong>To:</strong> dev@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> Fix the login bug</p>
|
|
||||||
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:15:00 +0000</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi,</p>
|
|
||||||
<p>Please fix the login bug in Project Alpha as soon as possible.
|
|
||||||
Users are reporting that they can't log in with their Google accounts.
|
|
||||||
This is blocking the whole team. Please resolve it by Friday.</p>
|
|
||||||
<p>Thanks,<br>Boss</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
23
tests/fixtures/journey_v2/data/email_info.html
vendored
23
tests/fixtures/journey_v2/data/email_info.html
vendored
@@ -1,23 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Email: New policy update</title>
|
|
||||||
<style>body { font-family: Arial; }</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> hr@company.com</p>
|
|
||||||
<p><strong>To:</strong> all@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> FYI: New remote work policy effective May 1</p>
|
|
||||||
<p><strong>Date:</strong> Tue, 8 Apr 2026 10:00:00 +0000</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi everyone,</p>
|
|
||||||
<p>Just a heads-up that starting May 1, 2026 the company will be moving to
|
|
||||||
a hybrid work model. You will be expected to come into the office at least
|
|
||||||
two days per week. More details will follow in the employee handbook.</p>
|
|
||||||
<p>Best,<br>HR Team</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
68
tests/fixtures/preprocessors/cases.yaml
vendored
68
tests/fixtures/preprocessors/cases.yaml
vendored
@@ -1,68 +0,0 @@
|
|||||||
# Preprocessor test cases
|
|
||||||
#
|
|
||||||
# detect: <expected_type> → chiama detect_content_type(filename, content)
|
|
||||||
# process: <content_type> → chiama preprocess(content_type, content)
|
|
||||||
#
|
|
||||||
# Sorgente: file: <nome in data/> oppure generate: binary_noise
|
|
||||||
#
|
|
||||||
# Assertions piatte (solo per process):
|
|
||||||
# no_html: true clean_text senza tag HTML
|
|
||||||
# min_chars: N len(clean_text) >= N
|
|
||||||
# ratio_lt: F len(clean) / len(raw) < F
|
|
||||||
# has_meta: [k, ...] chiavi presenti in metadata
|
|
||||||
# contains: str | [str] substring(s) presenti in clean_text
|
|
||||||
# excludes: str | [str] substring(s) assenti da clean_text
|
|
||||||
# content_type: str result.content_type == questo valore
|
|
||||||
|
|
||||||
- id: "1.1"
|
|
||||||
file: email_action.html
|
|
||||||
detect: email_html
|
|
||||||
|
|
||||||
- id: "1.2"
|
|
||||||
file: generic_page.html
|
|
||||||
detect: generic_html
|
|
||||||
|
|
||||||
- id: "1.3"
|
|
||||||
file: notes.txt
|
|
||||||
detect: plain_text
|
|
||||||
|
|
||||||
- id: "1.4"
|
|
||||||
file: archive.xyz
|
|
||||||
generate: binary_noise
|
|
||||||
detect: unknown
|
|
||||||
|
|
||||||
- id: "1.5"
|
|
||||||
file: email_action.html
|
|
||||||
process: email_html
|
|
||||||
no_html: true
|
|
||||||
min_chars: 50
|
|
||||||
ratio_lt: 0.8
|
|
||||||
|
|
||||||
- id: "1.6"
|
|
||||||
file: email_action.html
|
|
||||||
process: email_html
|
|
||||||
has_meta: [subject, from]
|
|
||||||
|
|
||||||
- id: "1.7"
|
|
||||||
file: email_thread.html
|
|
||||||
process: email_html
|
|
||||||
contains: "Sure, I'll handle the deploy"
|
|
||||||
excludes: "Let's plan the deploy"
|
|
||||||
|
|
||||||
- id: "1.8"
|
|
||||||
file: email_single.html
|
|
||||||
process: email_html
|
|
||||||
contains: "deploy is done"
|
|
||||||
|
|
||||||
- id: "1.9"
|
|
||||||
file: email_heavy.html
|
|
||||||
process: email_html
|
|
||||||
no_html: true
|
|
||||||
min_chars: 30
|
|
||||||
excludes: [border-collapse, font-size]
|
|
||||||
|
|
||||||
- id: "1.10"
|
|
||||||
file: fallback.txt
|
|
||||||
process: unknown
|
|
||||||
min_chars: 1
|
|
||||||
content_type: unknown
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<title>Fix the login bug</title>
|
|
||||||
<style>
|
|
||||||
body { font-family: Arial, sans-serif; color: #333; margin: 0; padding: 20px; }
|
|
||||||
.header { background: #f5f5f5; padding: 10px; border-bottom: 1px solid #ddd; }
|
|
||||||
.body { padding: 20px; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> boss@company.com</p>
|
|
||||||
<p><strong>To:</strong> dev@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> Fix the login bug</p>
|
|
||||||
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:00:00 +0200</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi,</p>
|
|
||||||
<p>Please fix the login bug by Friday. It is blocking the release.</p>
|
|
||||||
<p>Priority: high. Let me know if you need anything.</p>
|
|
||||||
<p>Thanks,<br>Boss</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<style>
|
|
||||||
table { border-collapse: collapse; width: 100%; max-width: 600px; margin: 0 auto; }
|
|
||||||
td { padding: 8px 12px; border: 1px solid #dddddd; font-size: 12px; color: #444444; }
|
|
||||||
.header-row { background-color: #003366; color: #ffffff; font-weight: bold; }
|
|
||||||
.label-col { background-color: #f0f0f0; width: 80px; font-weight: bold; }
|
|
||||||
.footer-row { font-size: 10px; color: #999999; text-align: center; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body bgcolor="#eeeeee">
|
|
||||||
<center>
|
|
||||||
<table cellpadding="0" cellspacing="0">
|
|
||||||
<tr class="header-row">
|
|
||||||
<td colspan="2">Company Internal Update</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">From:</td>
|
|
||||||
<td>newsletter@corp.com</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">Subject:</td>
|
|
||||||
<td>Q1 Results Update</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">Date:</td>
|
|
||||||
<td>Apr 7, 2026</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td colspan="2">
|
|
||||||
<table width="100%" cellpadding="10">
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<p style="font-size:14px; font-weight:bold;">Dear Team,</p>
|
|
||||||
<p>Q1 results are in. Revenue up 15% year-over-year.</p>
|
|
||||||
<p>Please review the attached report and share any feedback by EOW.</p>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr class="footer-row">
|
|
||||||
<td colspan="2">Confidential — do not forward outside the company.</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
</center>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html><body>
|
|
||||||
<p><strong>From:</strong> alice@co.com</p>
|
|
||||||
<p><strong>To:</strong> team@co.com</p>
|
|
||||||
<p><strong>Subject:</strong> Quick update</p>
|
|
||||||
<p><strong>Date:</strong> Tue, 7 Apr 2026 10:30:00 +0200</p>
|
|
||||||
<p>The deploy is done. Everything looks good. No issues so far.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html><body>
|
|
||||||
<div class="message-latest">
|
|
||||||
<p><strong>From:</strong> alice@co.com</p>
|
|
||||||
<p><strong>Subject:</strong> Re: Re: Deploy plan</p>
|
|
||||||
<p>Sure, I'll handle the deploy.</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<p>On Mon, Apr 6, 2026 at 3:00 PM, Bob <bob@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: bob@co.com</p>
|
|
||||||
<p>Can you handle the deploy?</p>
|
|
||||||
<p>On Sun, Apr 5, 2026 at 1:00 PM, Alice <alice@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: alice@co.com</p>
|
|
||||||
<p>Let's plan the deploy for Monday.</p>
|
|
||||||
<p>On Sat, Apr 4, 2026 at 11:00 AM, Charlie <charlie@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: charlie@co.com</p>
|
|
||||||
<p>We need to schedule the deploy. What day works?</p>
|
|
||||||
</blockquote>
|
|
||||||
</blockquote>
|
|
||||||
</blockquote>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
random text content without any structure
|
|
||||||
line two with some words
|
|
||||||
line three and more content here
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>My Web App</title>
|
|
||||||
<link rel="stylesheet" href="styles.css">
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<nav>
|
|
||||||
<a href="/">Home</a>
|
|
||||||
<a href="/about">About</a>
|
|
||||||
<a href="/contact">Contact</a>
|
|
||||||
</nav>
|
|
||||||
<main>
|
|
||||||
<header>
|
|
||||||
<h1>Welcome to My App</h1>
|
|
||||||
</header>
|
|
||||||
<article>
|
|
||||||
<p>This is a generic web page with no email headers.</p>
|
|
||||||
<p>It has navigation, main content, and a footer.</p>
|
|
||||||
</article>
|
|
||||||
<section>
|
|
||||||
<h2>Features</h2>
|
|
||||||
<ul>
|
|
||||||
<li>Fast</li>
|
|
||||||
<li>Reliable</li>
|
|
||||||
<li>Secure</li>
|
|
||||||
</ul>
|
|
||||||
</section>
|
|
||||||
</main>
|
|
||||||
<footer>
|
|
||||||
<p>© 2026 My App</p>
|
|
||||||
</footer>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
15
tests/fixtures/preprocessors/data/notes.txt
vendored
15
tests/fixtures/preprocessors/data/notes.txt
vendored
@@ -1,15 +0,0 @@
|
|||||||
Meeting notes - April 7, 2026
|
|
||||||
|
|
||||||
Attendees: Alice, Bob, Charlie
|
|
||||||
|
|
||||||
Discussion points:
|
|
||||||
- Deploy scheduled for Friday
|
|
||||||
- Bug fix for login must be completed by Thursday
|
|
||||||
- Review Q1 numbers before EOW
|
|
||||||
|
|
||||||
Action items:
|
|
||||||
- Alice: fix login bug
|
|
||||||
- Bob: prepare deploy checklist
|
|
||||||
- Charlie: send Q1 report
|
|
||||||
|
|
||||||
Next meeting: April 14, 2026
|
|
||||||
@@ -10,13 +10,13 @@ Coverage:
|
|||||||
- run_local_agent — file-read timeout path
|
- run_local_agent — file-read timeout path
|
||||||
- run_local_agent — LLM extraction error path
|
- run_local_agent — LLM extraction error path
|
||||||
- run_cloud_agent — stub returns error immediately
|
- run_cloud_agent — stub returns error immediately
|
||||||
- trigger_pending_runs — skipped when config is client-owned
|
- trigger_pending_runs — overdue local + cloud dispatched
|
||||||
- trigger_pending_runs — non-overdue skipped
|
- trigger_pending_runs — non-overdue skipped
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
Integration:
|
Integration:
|
||||||
- POST /agents/can-create — billing eligibility check
|
- POST /agents/{id}/run — 404 on unknown agent
|
||||||
- POST /agents/trigger — creates run log + dispatches background task
|
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
|
|||||||
assert kwargs["items_processed"] == 1
|
assert kwargs["items_processed"] == 1
|
||||||
assert kwargs["items_created"] == 1
|
assert kwargs["items_created"] == 1
|
||||||
assert kwargs["errors"] == []
|
assert kwargs["errors"] == []
|
||||||
assert kwargs["update_config_last_run"] is False
|
assert kwargs["update_config_last_run"] is True
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
# Verify agent_run frame was sent.
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
@@ -690,11 +690,31 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
"""Pending-run scan is skipped because agent config is client-owned."""
|
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
config = _make_local_config()
|
||||||
|
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||||
|
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -702,11 +722,31 @@ async def test_trigger_pending_runs_no_overdue():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
"""Device filtering is no longer backend-managed in pending runs."""
|
"""Local agents are only triggered for the matching device_id."""
|
||||||
|
# The DB query already filters by device_id, so we verify the SELECT
|
||||||
|
# includes the device_id filter by checking that a config bound to a
|
||||||
|
# different device is never dispatched.
|
||||||
|
#
|
||||||
|
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||||
|
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -714,18 +754,56 @@ async def test_trigger_pending_runs_device_id_filter():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
"""No pending runs are dispatched by backend after config deprecation."""
|
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||||
|
config = _make_local_config() # last_run_at=None → always overdue
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||||
|
call_order.append("run_local")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||||
|
# First call: query configs. Subsequent calls: create run_log.
|
||||||
|
mock_query_ctx = AsyncMock()
|
||||||
|
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||||
|
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_query_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_log_obj = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_FREE_UID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
mock_insert_ctx = AsyncMock()
|
||||||
|
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||||
|
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_insert_ctx.add = MagicMock()
|
||||||
|
mock_insert_ctx.commit = AsyncMock()
|
||||||
|
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||||
|
|
||||||
|
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
assert call_order == ["run_local"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: POST /agents/can-create and /agents/trigger
|
# Integration: POST /agents/{id}/run
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -742,67 +820,50 @@ def _override_db(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_can_create_agent_allows_when_under_limit(client):
|
async def test_trigger_run_unknown_agent(client):
|
||||||
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/v1/agents/can-create",
|
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||||
json={"active_agents": 0},
|
headers=auth_header("power"),
|
||||||
headers=auth_header("free"),
|
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 404
|
||||||
body = resp.json()
|
|
||||||
assert body["allowed"] is True
|
|
||||||
assert body["tier"] == "free"
|
|
||||||
assert body["active_agents"] == 0
|
|
||||||
assert body["limit"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_can_create_agent_denies_when_at_limit(client):
|
|
||||||
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/can-create",
|
|
||||||
json={"active_agents": 2},
|
|
||||||
headers=auth_header("free"),
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["allowed"] is False
|
|
||||||
assert body["limit"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||||
dispatched: list[tuple[str, str]] = []
|
# Create the local agent config in the DB.
|
||||||
|
config = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=TEST_USER_IDS["power"],
|
||||||
|
device_id="dev-001",
|
||||||
|
name="My Agent",
|
||||||
|
directory_paths=["/home/user/docs"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks.",
|
||||||
|
file_extensions=[".txt"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
db_session.add(config)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
dispatched: list = []
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
dispatched.append((user_id, cfg.id))
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
def _fake_create_task(coro):
|
|
||||||
coro.close()
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
|
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
mock_create_task.side_effect = _fake_create_task
|
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/v1/agents/trigger",
|
f"/api/v1/agents/{config.id}/run",
|
||||||
json={
|
|
||||||
"directory": "/home/user/docs",
|
|
||||||
"what_to_extract": ["task", "note"],
|
|
||||||
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
|
||||||
"batch_interval": "0 */6 * * *",
|
|
||||||
"custom_agent_prompt": "Extract tasks and notes.",
|
|
||||||
"active_agents": 0,
|
|
||||||
},
|
|
||||||
headers=auth_header("power"),
|
headers=auth_header("power"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status_code == 202
|
assert resp.status_code == 202
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert isinstance(data["agent_id"], str)
|
assert data["agent_id"] == config.id
|
||||||
assert data["agent_id"]
|
|
||||||
assert data["status"] == "running"
|
assert data["status"] == "running"
|
||||||
assert data["agent_type"] == "local"
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
|||||||
@@ -1,432 +0,0 @@
|
|||||||
"""Tests for Local Agent V2 runner (Step 2).
|
|
||||||
|
|
||||||
Covers the unified per-file flow:
|
|
||||||
Phase A — detect + preprocess (Python, zero LLM)
|
|
||||||
Phase B — single LLM call with tools (classify + extract + create)
|
|
||||||
|
|
||||||
Fixture-based eval tests (2.1–2.7)
|
|
||||||
-----------------------------------
|
|
||||||
Cases are defined in tests/fixtures/agent_runner_v2/cases.yaml.
|
|
||||||
Email HTML files live in tests/fixtures/agent_runner_v2/data/.
|
|
||||||
Use --runner-dir to point at a custom folder (same structure required).
|
|
||||||
|
|
||||||
Unit tests (no LLM)
|
|
||||||
--------------------
|
|
||||||
2.8 items_created count → items_created == N create_* calls
|
|
||||||
2.9 Device offline → status=error
|
|
||||||
2.10 Empty file → items_processed=0, status=success
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_agent_runner_v2.py -v
|
|
||||||
pytest tests/test_agent_runner_v2.py -v -k "2_9 or 2_10 or 2_8" # unit only
|
|
||||||
pytest tests/test_agent_runner_v2.py -v -k "eval" # LLM evals only
|
|
||||||
pytest tests/test_agent_runner_v2.py -v --runner-dir /path/to/dir # custom fixtures
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
|
||||||
_format_metadata,
|
|
||||||
_format_projects,
|
|
||||||
_get_extraction_rules,
|
|
||||||
_get_no_match_behavior,
|
|
||||||
_is_overdue,
|
|
||||||
run_local_agent,
|
|
||||||
)
|
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS
|
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_USER_ID = TEST_USER_IDS["power"]
|
|
||||||
|
|
||||||
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "agent_runner_v2"
|
|
||||||
|
|
||||||
_AGENT_CONFIG = {
|
|
||||||
"content_types": [
|
|
||||||
{
|
|
||||||
"id": "email_html",
|
|
||||||
"label": "Email HTML",
|
|
||||||
"detection_hint": "HTML file with From/To/Subject headers",
|
|
||||||
"preprocessing": "email_html",
|
|
||||||
"extraction_prompt": (
|
|
||||||
"If the email contains a direct action request or task assignment → create a task. "
|
|
||||||
"If the email contains informational content, updates, or FYI → create a note. "
|
|
||||||
"If the email mentions a specific date for a meeting or deadline → create a timeline entry."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"global_rules": [
|
|
||||||
"Se il file non è riconducibile a nessun progetto, non creare alcuna entità."
|
|
||||||
],
|
|
||||||
"data_types": ["tasks", "notes", "timelines"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Canonical project definitions, referenced symbolically in cases.yaml.
|
|
||||||
_PROJECTS: dict[str, dict] = {
|
|
||||||
"alpha": {"id": "proj-alpha", "name": "Project Alpha", "status": "active"},
|
|
||||||
"beta": {"id": "proj-beta", "name": "Project Beta", "status": "active"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixture loading ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--runner-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load(
|
|
||||||
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_case_file(case: dict, data_dir: Path) -> str:
|
|
||||||
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_projects(entries: list[str | dict]) -> list[dict]:
|
|
||||||
"""Resolve project list from YAML: symbolic names and/or inline dicts."""
|
|
||||||
result = []
|
|
||||||
for entry in entries:
|
|
||||||
if isinstance(entry, str):
|
|
||||||
if entry in _PROJECTS:
|
|
||||||
result.append(_PROJECTS[entry])
|
|
||||||
elif isinstance(entry, dict):
|
|
||||||
result.append(entry)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ── pytest_generate_tests — parametrize eval tests from YAML ─────────────
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "runner_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
metafunc.parametrize("runner_case", cases, ids=[c["id"] for c in cases])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Test helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_config(
|
|
||||||
agent_config: dict | None = None,
|
|
||||||
directory: str = "/emails",
|
|
||||||
device_id: str = "dev-001",
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
return LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=_USER_ID,
|
|
||||||
device_id=device_id,
|
|
||||||
name="Test V2 Agent",
|
|
||||||
directory_paths=[directory],
|
|
||||||
data_types=["tasks", "notes", "timelines"],
|
|
||||||
prompt_template="",
|
|
||||||
agent_config=agent_config or _AGENT_CONFIG,
|
|
||||||
file_extensions=[".html", ".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_run_log(agent_id: str) -> AgentRunLog:
|
|
||||||
return AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_USER_ID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_manager(online: bool = True) -> DeviceConnectionManager:
|
|
||||||
mgr = DeviceConnectionManager()
|
|
||||||
if online:
|
|
||||||
ws = MagicMock()
|
|
||||||
ws.send_text = AsyncMock()
|
|
||||||
mgr.register(_USER_ID, "dev-001", ws)
|
|
||||||
return mgr
|
|
||||||
|
|
||||||
|
|
||||||
def _make_executor(
|
|
||||||
file_path: str,
|
|
||||||
file_content: str,
|
|
||||||
projects: list[dict] | None = None,
|
|
||||||
existing_tasks: list[dict] | None = None,
|
|
||||||
existing_notes: list[dict] | None = None,
|
|
||||||
existing_timelines: list[dict] | None = None,
|
|
||||||
) -> tuple[Any, list[dict]]:
|
|
||||||
"""Return (async_executor, captured_calls).
|
|
||||||
|
|
||||||
The executor handles all ``execute_on_client`` payloads:
|
|
||||||
directory listing, file reading, project/entity fetching, and CRUD.
|
|
||||||
"""
|
|
||||||
calls: list[dict] = []
|
|
||||||
_projects = projects if projects is not None else list(_PROJECTS.values())
|
|
||||||
|
|
||||||
async def _executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
table = payload.get("table", "")
|
|
||||||
data = payload.get("data") or {}
|
|
||||||
calls.append({"action": action, "table": table, "data": data})
|
|
||||||
|
|
||||||
if action == "list_directory":
|
|
||||||
return {"entries": [{"type": "file", "path": file_path}]}
|
|
||||||
|
|
||||||
if action == "get_file_metadata":
|
|
||||||
return {"modifiedAt": None}
|
|
||||||
|
|
||||||
if action == "read_file_content":
|
|
||||||
return {"content": file_content}
|
|
||||||
|
|
||||||
if action == "select":
|
|
||||||
if table == "projects":
|
|
||||||
return {"rows": _projects}
|
|
||||||
if table == "tasks":
|
|
||||||
return {"rows": existing_tasks or []}
|
|
||||||
if table == "notes":
|
|
||||||
return {"rows": existing_notes or []}
|
|
||||||
if table == "timelines":
|
|
||||||
return {"rows": existing_timelines or []}
|
|
||||||
return {"rows": []}
|
|
||||||
|
|
||||||
if action == "insert":
|
|
||||||
return {"row": {"id": str(uuid.uuid4()), **data}}
|
|
||||||
|
|
||||||
if action == "update":
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return _executor, calls
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: helper functions ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_projects_empty():
|
|
||||||
assert "(no projects" in _format_projects([])
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_projects_with_data():
|
|
||||||
result = _format_projects([_PROJECTS["alpha"]])
|
|
||||||
assert "proj-alpha" in result
|
|
||||||
assert "Project Alpha" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_metadata_empty():
|
|
||||||
assert _format_metadata({}) == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_metadata_email():
|
|
||||||
meta = {"subject": "Fix bug", "from": "boss@co.com", "date": "2026-04-07"}
|
|
||||||
result = _format_metadata(meta)
|
|
||||||
assert "Fix bug" in result
|
|
||||||
assert "boss@co.com" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_extraction_rules_match():
|
|
||||||
rules = _get_extraction_rules(_AGENT_CONFIG, "email_html")
|
|
||||||
assert "task" in rules.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_extraction_rules_fallback():
|
|
||||||
rules = _get_extraction_rules(_AGENT_CONFIG, "plain_text")
|
|
||||||
assert "extract" in rules.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_no_match_behavior_from_global_rules():
|
|
||||||
behavior = _get_no_match_behavior(_AGENT_CONFIG)
|
|
||||||
assert behavior # non-empty
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_no_match_behavior_default():
|
|
||||||
behavior = _get_no_match_behavior({})
|
|
||||||
assert "project" in behavior.lower()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.9 — device offline ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_9_device_offline():
|
|
||||||
"""2.9 No device online → status=error, no executor created."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager(online=False)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("not connected" in e for e in kwargs.get("errors", []))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.10 — empty file ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_10_empty_file():
|
|
||||||
"""2.10 File with empty content → skipped, items_processed=0, success."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, calls = _make_executor(
|
|
||||||
file_path="/emails/empty.html",
|
|
||||||
file_content="",
|
|
||||||
projects=[_PROJECTS["alpha"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
assert kwargs["items_processed"] == 0
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_created"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.8 — items_created count ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_8_items_created_count():
|
|
||||||
"""2.8 items_created == number of create_* tool calls per run."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, _calls = _make_executor(
|
|
||||||
file_path="/emails/action.html",
|
|
||||||
file_content="<html><body><p>Fix the login bug in Project Alpha.</p></body></html>",
|
|
||||||
projects=[_PROJECTS["alpha"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_run_agent(*, _tool_calls_out=None, **kw) -> str:
|
|
||||||
if _tool_calls_out is not None:
|
|
||||||
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
|
||||||
return "Done."
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
# Only create_task + create_note count (not update_task).
|
|
||||||
assert kwargs["items_created"] == 2
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
# ── Eval: 2.1–2.7 — fixture-driven, real LLM + Langfuse scoring ──────────
|
|
||||||
#
|
|
||||||
# Cases loaded from tests/fixtures/agent_runner_v2/cases.yaml.
|
|
||||||
# Supported assertions (from YAML):
|
|
||||||
# expect_insert: <table> → at least 1 insert in that table
|
|
||||||
# expect_no_insert: true → zero inserts in any table
|
|
||||||
# expect_project_id: <id> → any insert carries this projectId
|
|
||||||
# expect_dedup: true → task inserts == 0 OR task updates >= 1
|
|
||||||
# ─────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.eval
|
|
||||||
async def test_eval_runner(runner_case, pytestconfig):
|
|
||||||
"""Parametrized eval test — one invocation per YAML case."""
|
|
||||||
case: dict = runner_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
file_content = _read_case_file(case, data_dir)
|
|
||||||
projects = _resolve_projects(case.get("projects", []))
|
|
||||||
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, calls = _make_executor(
|
|
||||||
file_path=case["file_path"],
|
|
||||||
file_content=file_content,
|
|
||||||
projects=projects,
|
|
||||||
existing_tasks=case.get("existing_tasks"),
|
|
||||||
existing_notes=case.get("existing_notes"),
|
|
||||||
existing_timelines=case.get("existing_timelines"),
|
|
||||||
)
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
obs_ctx = lf.start_as_current_observation(
|
|
||||||
name=f"eval-runner-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
|
||||||
metadata={"step": "2", "case_id": case["id"]},
|
|
||||||
) if lf else nullcontext()
|
|
||||||
|
|
||||||
with obs_ctx as obs:
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
inserts = [c for c in calls if c["action"] == "insert"]
|
|
||||||
score, comment = _evaluate_case(case, calls, kwargs)
|
|
||||||
|
|
||||||
if obs is not None:
|
|
||||||
obs.score(
|
|
||||||
name=case.get("score_name", f"runner.case_{case['id']}"),
|
|
||||||
value=score,
|
|
||||||
comment=comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_case(case: dict, calls: list[dict], finalize_kwargs: dict) -> tuple[float, str]:
|
|
||||||
"""Return (score, comment) for a YAML case given the captured executor calls."""
|
|
||||||
inserts = [c for c in calls if c["action"] == "insert"]
|
|
||||||
|
|
||||||
if case.get("expect_no_insert"):
|
|
||||||
score = 1.0 if len(inserts) == 0 else 0.0
|
|
||||||
return score, f"inserts={len(inserts)} (expected 0)"
|
|
||||||
|
|
||||||
if "expect_insert" in case:
|
|
||||||
tables = case["expect_insert"]
|
|
||||||
if isinstance(tables, str):
|
|
||||||
tables = [tables]
|
|
||||||
missing = [t for t in tables if not any(c["table"] == t for c in inserts)]
|
|
||||||
score = 1.0 if not missing else 0.0
|
|
||||||
counts = {t: sum(1 for c in inserts if c["table"] == t) for t in tables}
|
|
||||||
return score, f"inserts={counts}" + (f" missing={missing}" if missing else "")
|
|
||||||
|
|
||||||
if "expect_project_id" in case:
|
|
||||||
expected_pid = case["expect_project_id"]
|
|
||||||
correct = any(c.get("data", {}).get("projectId") == expected_pid for c in inserts)
|
|
||||||
score = 1.0 if correct else 0.0
|
|
||||||
all_pids = [c.get("data", {}).get("projectId") for c in inserts]
|
|
||||||
return score, f"projectIds={all_pids} (expected {expected_pid!r})"
|
|
||||||
|
|
||||||
if case.get("expect_dedup"):
|
|
||||||
task_creates = [c for c in inserts if c["table"] == "tasks"]
|
|
||||||
task_updates = [c for c in calls if c["action"] == "update" and c["table"] == "tasks"]
|
|
||||||
score = 1.0 if len(task_creates) == 0 or len(task_updates) >= 1 else 0.0
|
|
||||||
return score, f"task_creates={len(task_creates)} task_updates={len(task_updates)}"
|
|
||||||
|
|
||||||
return 0.0, "no assertion defined in case"
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
"""Unit tests for Step 1 file classification (_classify_file).
|
|
||||||
|
|
||||||
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
|
|
||||||
Run with: pytest tests/test_classify_file.py -v
|
|
||||||
|
|
||||||
To run a quick manual check against a real file without the full UI:
|
|
||||||
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_runner import _classify_file
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
PROJECTS_SAMPLE = [
|
|
||||||
{
|
|
||||||
"id": "aaaa-0001-0000-0000-000000000001",
|
|
||||||
"name": "ARPA Sicilia POC",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "bbbb-0002-0000-0000-000000000002",
|
|
||||||
"name": "SNAM AI Meeting Prep",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cccc-0003-0000-0000-000000000003",
|
|
||||||
"name": "SFERA+ Wave 2",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "Second wave of the SFERA+ whitelist project.",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
ARPA_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
|
|
||||||
isImportance: normal
|
|
||||||
hasAttachment: True
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Buongiorno,
|
|
||||||
|
|
||||||
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
|
|
||||||
dei deliverable concordati:
|
|
||||||
- Preparare demo entro il 30 marzo
|
|
||||||
- Condividere documentazione tecnica con il team ARPA
|
|
||||||
- Fissare call di follow-up la prossima settimana
|
|
||||||
|
|
||||||
Cordiali saluti
|
|
||||||
Roberto Marchetti
|
|
||||||
"""
|
|
||||||
|
|
||||||
SNAM_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com
|
|
||||||
isImportance: high
|
|
||||||
hasAttachment: False
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Ciao,
|
|
||||||
ti invio l'agenda per la riunione SNAM di domani.
|
|
||||||
Per favore conferma la tua presenza.
|
|
||||||
"""
|
|
||||||
|
|
||||||
UNRELATED_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com
|
|
||||||
isImportance: normal
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Benvenuto nel programma HPE Employee Learning Series.
|
|
||||||
Completa la formazione richiesta entro la fine del trimestre.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_arpa_matches_existing():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="arpa_email.txt",
|
|
||||||
file_content=ARPA_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes", "timelines"],
|
|
||||||
)
|
|
||||||
assert project_id == "aaaa-0001-0000-0000-000000000001", (
|
|
||||||
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
|
|
||||||
)
|
|
||||||
assert new_name is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_snam_matches_existing():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="snam_email.txt",
|
|
||||||
file_content=SNAM_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "bbbb-0002-0000-0000-000000000002", (
|
|
||||||
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_unrelated_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="learning_email.txt",
|
|
||||||
file_content=UNRELATED_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
assert new_name is not None # LLM should suggest a name
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_empty_file_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="empty.txt",
|
|
||||||
file_content=" ",
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_no_projects_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="arpa_email.txt",
|
|
||||||
file_content=ARPA_EMAIL,
|
|
||||||
projects=[],
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
assert new_name is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ── CLI quick-test runner ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _cli_test(file_path: str, project_names: list[str]) -> None:
|
|
||||||
"""Run Step 1 classification against a real file from the CLI."""
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
|
|
||||||
projects = [
|
|
||||||
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
|
|
||||||
for i, name in enumerate(project_names)
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"\nClassifying: {file_path}")
|
|
||||||
print(f"Projects in context: {[p['name'] for p in projects]}\n")
|
|
||||||
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=content,
|
|
||||||
projects=projects,
|
|
||||||
config_data_types=["tasks", "notes", "timelines"],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"project_id": project_id,
|
|
||||||
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
|
|
||||||
"new_project_name": new_name,
|
|
||||||
"domains": domains,
|
|
||||||
}
|
|
||||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) < 2:
|
|
||||||
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
|
|
||||||
sys.exit(1)
|
|
||||||
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import date, timedelta
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.deep_agent import (
|
|
||||||
_infer_floating_domain,
|
|
||||||
_normalize_tagged_list_lines,
|
|
||||||
run_floating,
|
|
||||||
run_floating_stream,
|
|
||||||
run_home,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeTool:
|
|
||||||
name = "list_tasks"
|
|
||||||
|
|
||||||
async def ainvoke(self, args):
|
|
||||||
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeLLM:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.agent_calls = 0
|
|
||||||
|
|
||||||
def bind_tools(self, _tools):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def ainvoke(self, messages):
|
|
||||||
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
|
||||||
if "strict domain classifier" in system_prompt:
|
|
||||||
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
|
||||||
|
|
||||||
self.agent_calls += 1
|
|
||||||
if self.agent_calls == 1:
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call-1",
|
|
||||||
"name": "list_tasks",
|
|
||||||
"args": {"project_id": "proj-1"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
|
||||||
assert tool_messages, "Expected at least one tool message"
|
|
||||||
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
|
||||||
|
|
||||||
async def astream(self, _messages):
|
|
||||||
yield SimpleNamespace(content="stream-")
|
|
||||||
yield SimpleNamespace(content="ok")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_home_uses_mocked_tool_result():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
out = await run_home("user-1", "list my tasks", {})
|
|
||||||
|
|
||||||
assert "Final answer from mocked tool" in out
|
|
||||||
assert "Mock Task" in out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"show me timeline updates",
|
|
||||||
{"scope": {"type": "timeline", "id": "tl-1"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert events[0] == (
|
|
||||||
"floating_domain",
|
|
||||||
{"type": "timeline", "id": "tl-1", "section": None},
|
|
||||||
)
|
|
||||||
assert ("token", "stream-") in events
|
|
||||||
assert ("token", "ok") in events
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
|
||||||
class _ClassifierOnlyLLM:
|
|
||||||
async def ainvoke(self, _messages):
|
|
||||||
return AIMessage(
|
|
||||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
|
||||||
domain = await _infer_floating_domain(
|
|
||||||
"Quali sono i miei task per il progetto X",
|
|
||||||
{
|
|
||||||
"scope": {"type": "timeline"},
|
|
||||||
"resolved_project_id": "213213-312321-312312-421321",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert domain == {
|
|
||||||
"type": "project",
|
|
||||||
"id": "213213-312321-312312-421321",
|
|
||||||
"section": "task",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
|
||||||
raw = (
|
|
||||||
"Certo!\n\n"
|
|
||||||
"1. **Task A** — priorita high <task>[task-1]</task>\n"
|
|
||||||
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
|
|
||||||
|
|
||||||
assert "<task>[task-1]</task>" in out
|
|
||||||
assert "<task>[task-2]</task>" in out
|
|
||||||
assert "Task A" not in out
|
|
||||||
assert "Task B" not in out
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
|
|
||||||
today = date.today()
|
|
||||||
tomorrow = today + timedelta(days=1)
|
|
||||||
yesterday = today - timedelta(days=1)
|
|
||||||
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
|
|
||||||
|
|
||||||
raw = "\n".join(
|
|
||||||
[
|
|
||||||
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
|
|
||||||
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
|
|
||||||
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
|
|
||||||
|
|
||||||
assert "<timeline>[tl-next]</timeline>" in out
|
|
||||||
assert "<timeline>[tl-old]</timeline>" not in out
|
|
||||||
assert "<timeline>[tl-future]</timeline>" not in out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_run_single_agent(**_kwargs):
|
|
||||||
return (
|
|
||||||
"Hai 1 task:\\n"
|
|
||||||
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
|
||||||
):
|
|
||||||
text, _domain = await run_floating(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "<task>" not in text
|
|
||||||
assert "</task>" not in text
|
|
||||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_stream(**_kwargs):
|
|
||||||
yield "token", "Hai 1 task:\\n"
|
|
||||||
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
|
||||||
combined = "".join(token_events)
|
|
||||||
assert "<task>" not in combined
|
|
||||||
assert "</task>" not in combined
|
|
||||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
|
||||||
class _NoChunkLLM:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.calls = 0
|
|
||||||
|
|
||||||
def bind_tools(self, _tools):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def ainvoke(self, _messages):
|
|
||||||
self.calls += 1
|
|
||||||
if self.calls == 1:
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call-1",
|
|
||||||
"name": "list_tasks",
|
|
||||||
"args": {},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return AIMessage(content="No notes found.")
|
|
||||||
|
|
||||||
async def astream(self, _messages):
|
|
||||||
if False:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali sono le note?",
|
|
||||||
{"scope": {"type": "note"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert events[0][0] == "floating_domain"
|
|
||||||
assert ("token", "No notes found.") in events
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_run_single_agent(**_kwargs):
|
|
||||||
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
|
||||||
):
|
|
||||||
text, _domain = await run_floating(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert text == "No results found."
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_stream(**_kwargs):
|
|
||||||
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert ("token", "No results found.") in events
|
|
||||||
@@ -1,349 +0,0 @@
|
|||||||
"""Tests for Local Agent V2 journey setup (Step 4).
|
|
||||||
|
|
||||||
Covers the chatbot journey that produces a structured AgentConfig JSON
|
|
||||||
instead of a freeform prompt_template string.
|
|
||||||
|
|
||||||
Unit tests (no LLM)
|
|
||||||
--------------------
|
|
||||||
4.6a _extract_agent_config: valid JSON → returns serialised config
|
|
||||||
4.6b _extract_agent_config: invalid JSON → returns None
|
|
||||||
4.6c _extract_agent_config: markers absent → returns None
|
|
||||||
4.6d _extract_agent_config: only START marker → returns None
|
|
||||||
4.6e Session not found → done=True, agent_config=None
|
|
||||||
4.6f Nudge uses AGENT_CONFIG_START/END markers (not old PROMPT_TEMPLATE)
|
|
||||||
|
|
||||||
Eval tests (real LLM + Langfuse scoring)
|
|
||||||
-----------------------------------------
|
|
||||||
Cases are defined in tests/fixtures/journey_v2/cases.yaml.
|
|
||||||
Email HTML files live in tests/fixtures/journey_v2/data/.
|
|
||||||
Use --journey-dir to point at a custom folder (same structure required).
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_journey_v2.py -v
|
|
||||||
pytest tests/test_journey_v2.py -v -k "4_6" # unit only
|
|
||||||
pytest tests/test_journey_v2.py -v -k "eval" # LLM evals only
|
|
||||||
pytest tests/test_journey_v2.py -v --journey-dir /p # custom fixtures
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.api.routes.agent_setup import (
|
|
||||||
_CONFIG_END,
|
|
||||||
_CONFIG_START,
|
|
||||||
_MAX_TURNS,
|
|
||||||
_extract_agent_config,
|
|
||||||
_sessions,
|
|
||||||
handle_journey_message,
|
|
||||||
handle_journey_start,
|
|
||||||
)
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
|
||||||
from app.schemas import AgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS
|
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_USER_ID = TEST_USER_IDS["power"]
|
|
||||||
|
|
||||||
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "journey_v2"
|
|
||||||
|
|
||||||
# ── Fixture loading ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--journey-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load(
|
|
||||||
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_data_file(filename: str, fixtures_dir: Path) -> str:
|
|
||||||
return (fixtures_dir / "data" / filename).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
# ── pytest_generate_tests ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "journey_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
metafunc.parametrize("journey_case", cases, ids=[c["id"] for c in cases])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Executor builder ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_fs_executor(directory_files: list[dict], fixtures_dir: Path):
|
|
||||||
"""Return an async callback that simulates filesystem tool responses.
|
|
||||||
|
|
||||||
Matches the signature expected by ``set_client_executor`` / ``execute_on_client``:
|
|
||||||
receives the full ``payload`` dict and returns a result dict.
|
|
||||||
|
|
||||||
``directory_files`` is a list of ``{path, content_file}`` dicts;
|
|
||||||
``content_file`` is relative to ``fixtures_dir/data/``.
|
|
||||||
"""
|
|
||||||
file_map: dict[str, str] = {
|
|
||||||
entry["path"]: _read_data_file(entry["content_file"], fixtures_dir)
|
|
||||||
for entry in directory_files
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
data = payload.get("data") or {}
|
|
||||||
|
|
||||||
if action == "list_directory":
|
|
||||||
return {"entries": [
|
|
||||||
{"type": "file", "name": p.split("/")[-1], "path": p}
|
|
||||||
for p in file_map
|
|
||||||
]}
|
|
||||||
|
|
||||||
if action == "read_file_content":
|
|
||||||
path = data.get("path", "")
|
|
||||||
return {"content": file_map.get(path, "")}
|
|
||||||
|
|
||||||
if action == "get_file_metadata":
|
|
||||||
path = data.get("path", "")
|
|
||||||
name = path.split("/")[-1]
|
|
||||||
ext = "." + name.rsplit(".", 1)[-1] if "." in name else ""
|
|
||||||
return {"name": name, "extension": ext, "size": 1024,
|
|
||||||
"createdAt": None, "modifiedAt": None}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return _executor
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey runner helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_journey(user_id: str, case: dict, executor) -> dict[str, Any]:
|
|
||||||
"""Drive start + all user_messages for a case. Returns the final reply dict.
|
|
||||||
|
|
||||||
Mirrors ``device_ws._handle_journey_start/message``: sets the client
|
|
||||||
executor (so filesystem tools work) before each handler call.
|
|
||||||
"""
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
try:
|
|
||||||
set_client_executor(executor)
|
|
||||||
reply = await handle_journey_start(user_id, {
|
|
||||||
"agent_type": "local",
|
|
||||||
"directory": case["directory"],
|
|
||||||
"data_types": case["data_types"],
|
|
||||||
"session_id": session_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
for msg in case.get("user_messages", []):
|
|
||||||
if reply.get("done"):
|
|
||||||
break
|
|
||||||
set_client_executor(executor)
|
|
||||||
reply = await handle_journey_message(user_id, {
|
|
||||||
"session_id": reply["session_id"],
|
|
||||||
"message": msg,
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
return reply
|
|
||||||
|
|
||||||
|
|
||||||
# ── Assertion helper ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
|
|
||||||
"""Return (score, comment) for a journey case given the final reply dict."""
|
|
||||||
if case.get("expect_question"):
|
|
||||||
has_q = "?" in reply.get("message", "")
|
|
||||||
return (1.0 if has_q else 0.0), f"first_reply_has_question={has_q}"
|
|
||||||
|
|
||||||
if case.get("expect_done") and not reply.get("done"):
|
|
||||||
return 0.0, "expected done=True but journey did not complete"
|
|
||||||
|
|
||||||
agent_config_raw = reply.get("agent_config")
|
|
||||||
|
|
||||||
if case.get("expect_valid_config"):
|
|
||||||
if not agent_config_raw:
|
|
||||||
return 0.0, "agent_config is None"
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(agent_config_raw)
|
|
||||||
valid = len(parsed.content_types) > 0
|
|
||||||
return (1.0 if valid else 0.0), f"content_types={len(parsed.content_types)}"
|
|
||||||
except Exception as exc:
|
|
||||||
return 0.0, f"parse error: {exc}"
|
|
||||||
|
|
||||||
if case.get("expect_content_type_id"):
|
|
||||||
expected_id = case["expect_content_type_id"]
|
|
||||||
if not agent_config_raw:
|
|
||||||
return 0.0, "agent_config is None"
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(agent_config_raw)
|
|
||||||
ids = [ct.id for ct in parsed.content_types]
|
|
||||||
found = expected_id in ids
|
|
||||||
return (1.0 if found else 0.0), f"content_type_ids={ids}, expected={expected_id}"
|
|
||||||
except Exception as exc:
|
|
||||||
return 0.0, f"parse error: {exc}"
|
|
||||||
|
|
||||||
if case.get("expect_extraction_contains"):
|
|
||||||
keyword = case["expect_extraction_contains"].lower()
|
|
||||||
if not agent_config_raw:
|
|
||||||
return 0.0, "agent_config is None"
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(agent_config_raw)
|
|
||||||
if not parsed.content_types:
|
|
||||||
return 0.0, "no content_types in config"
|
|
||||||
prompt = parsed.content_types[0].extraction_prompt.lower()
|
|
||||||
found = keyword in prompt
|
|
||||||
return (1.0 if found else 0.0), f"keyword='{keyword}' in extraction_prompt={found}"
|
|
||||||
except Exception as exc:
|
|
||||||
return 0.0, f"parse error: {exc}"
|
|
||||||
|
|
||||||
if case.get("expect_global_rules"):
|
|
||||||
if not agent_config_raw:
|
|
||||||
return 0.0, "agent_config is None"
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(agent_config_raw)
|
|
||||||
has_rules = len(parsed.global_rules) > 0
|
|
||||||
return (1.0 if has_rules else 0.0), f"global_rules={parsed.global_rules}"
|
|
||||||
except Exception as exc:
|
|
||||||
return 0.0, f"parse error: {exc}"
|
|
||||||
|
|
||||||
return 1.0, "no specific assertion"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit tests ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6a_extract_valid_json():
|
|
||||||
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
|
|
||||||
config = AgentConfig(
|
|
||||||
content_types=[],
|
|
||||||
global_rules=["No project = no entity"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
)
|
|
||||||
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
|
|
||||||
result = _extract_agent_config(text)
|
|
||||||
assert result is not None
|
|
||||||
parsed = AgentConfig.model_validate_json(result)
|
|
||||||
assert parsed.global_rules == ["No project = no entity"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6b_extract_invalid_json():
|
|
||||||
"""_extract_agent_config: malformed JSON between markers → returns None."""
|
|
||||||
text = f"{_CONFIG_START}\n{{not: valid json\n{_CONFIG_END}"
|
|
||||||
assert _extract_agent_config(text) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6c_extract_markers_absent():
|
|
||||||
"""_extract_agent_config: no markers at all → returns None."""
|
|
||||||
assert _extract_agent_config("No markers here at all") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6d_extract_only_start_marker():
|
|
||||||
"""_extract_agent_config: START without END → returns None."""
|
|
||||||
assert _extract_agent_config(f"text {_CONFIG_START} no end marker") is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_4_6e_session_not_found():
|
|
||||||
"""4.6e Session not found → done=True, agent_config=None, informative message."""
|
|
||||||
reply = await handle_journey_message(_USER_ID, {
|
|
||||||
"session_id": "nonexistent-session-id",
|
|
||||||
"message": "Hello",
|
|
||||||
})
|
|
||||||
assert reply["done"] is True
|
|
||||||
assert reply["agent_config"] is None
|
|
||||||
assert "not found" in reply["message"].lower() or "expired" in reply["message"].lower()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_4_6f_nudge_uses_new_markers():
|
|
||||||
"""4.6f Nudge injected after max turns uses AGENT_CONFIG markers, not PROMPT_TEMPLATE."""
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
captured_histories: list[list[dict]] = []
|
|
||||||
|
|
||||||
async def _mock_llm(system_prompt, history, tools, **kwargs) -> str:
|
|
||||||
captured_histories.append(list(history))
|
|
||||||
# Return plain text — no markers — to trigger the nudge path.
|
|
||||||
return "I still need more information from you."
|
|
||||||
|
|
||||||
from app.api.routes.agent_setup import JourneySession
|
|
||||||
|
|
||||||
fake_session = JourneySession(
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=_USER_ID,
|
|
||||||
agent_type="local",
|
|
||||||
directory="/test",
|
|
||||||
data_types=["tasks"],
|
|
||||||
system_prompt="system",
|
|
||||||
langfuse_prompt=None,
|
|
||||||
)
|
|
||||||
# Fill history to the turn limit so the next message triggers the nudge.
|
|
||||||
for i in range(_MAX_TURNS):
|
|
||||||
fake_session.history.append({"role": "user", "content": f"msg {i}"})
|
|
||||||
fake_session.history.append({"role": "assistant", "content": "ok"})
|
|
||||||
_sessions[session_id] = fake_session
|
|
||||||
|
|
||||||
try:
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm):
|
|
||||||
await handle_journey_message(_USER_ID, {
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": "one more message to trigger nudge",
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
# Second LLM call receives the nudge appended to history.
|
|
||||||
assert len(captured_histories) >= 2, "Expected ≥ 2 LLM calls (main reply + nudge)"
|
|
||||||
nudge_history = captured_histories[1]
|
|
||||||
user_msgs = " ".join(t["content"] for t in nudge_history if t["role"] == "user")
|
|
||||||
assert _CONFIG_START in user_msgs, f"Nudge must reference {_CONFIG_START}"
|
|
||||||
assert _CONFIG_END in user_msgs, f"Nudge must reference {_CONFIG_END}"
|
|
||||||
assert "PROMPT_TEMPLATE" not in user_msgs, "Old PROMPT_TEMPLATE markers must not appear in nudge"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Eval tests (real LLM + Langfuse) ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.eval
|
|
||||||
async def test_eval_journey(journey_case, pytestconfig):
|
|
||||||
"""Parametrized eval test — one invocation per YAML case."""
|
|
||||||
case: dict = journey_case
|
|
||||||
fixtures_dir = _fixtures_dir(pytestconfig)
|
|
||||||
executor = _make_fs_executor(case.get("directory_files", []), fixtures_dir)
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
obs_ctx = lf.start_as_current_observation(
|
|
||||||
name=f"eval-journey-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
|
||||||
metadata={"step": "4", "case_id": case["id"]},
|
|
||||||
) if lf else nullcontext()
|
|
||||||
|
|
||||||
with obs_ctx as obs:
|
|
||||||
reply = await _run_journey(_USER_ID, case, executor)
|
|
||||||
score, comment = _evaluate_case(case, reply)
|
|
||||||
|
|
||||||
if obs is not None:
|
|
||||||
obs.score(
|
|
||||||
name=case.get("score_name", f"journey.case_{case['id']}"),
|
|
||||||
value=score,
|
|
||||||
comment=comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
|
||||||
@@ -110,32 +110,6 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
|
|||||||
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
|
|
||||||
target_session = str(uuid.uuid4())
|
|
||||||
other_session = str(uuid.uuid4())
|
|
||||||
db_session.add(MemoryEpisodic(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=USER_ID,
|
|
||||||
summary_encrypted=_enc("Target session memory"),
|
|
||||||
session_id=target_session,
|
|
||||||
))
|
|
||||||
db_session.add(MemoryEpisodic(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=USER_ID,
|
|
||||||
summary_encrypted=_enc("Other session memory"),
|
|
||||||
session_id=other_session,
|
|
||||||
))
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
middleware = MemoryMiddleware(db_session)
|
|
||||||
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
|
|
||||||
|
|
||||||
episodic = ctx.get("episodic_memory", [])
|
|
||||||
assert any("Target session" in s for s in episodic)
|
|
||||||
assert not any("Other session" in s for s in episodic)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
# Add one pattern above threshold and one below
|
# Add one pattern above threshold and one below
|
||||||
@@ -255,40 +229,6 @@ async def test_update_core_upsert(db_session, user_with_key):
|
|||||||
assert _dec(rows[0].value_encrypted) == "fr"
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_core_block_edit_ops(db_session, user_with_key):
|
|
||||||
middleware = MemoryMiddleware(db_session)
|
|
||||||
|
|
||||||
await middleware.update_core(USER_ID, "human", "Name: Roberto")
|
|
||||||
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
|
|
||||||
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
|
|
||||||
|
|
||||||
blocks = await middleware.list_core_blocks(USER_ID)
|
|
||||||
human = next(b for b in blocks if b["label"] == "human")
|
|
||||||
|
|
||||||
assert replaced is True
|
|
||||||
assert "Name: Robert" in human["value"]
|
|
||||||
assert "Timezone: Europe/Rome" in human["value"]
|
|
||||||
|
|
||||||
deleted = await middleware.delete_core(USER_ID, "human")
|
|
||||||
assert deleted is True
|
|
||||||
assert await middleware.get_core_block(USER_ID, "human") is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
|
|
||||||
middleware = MemoryMiddleware(db_session)
|
|
||||||
|
|
||||||
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
|
|
||||||
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
|
|
||||||
|
|
||||||
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
|
|
||||||
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
|
|
||||||
|
|
||||||
assert any("whitelist" in item.lower() for item in arch)
|
|
||||||
assert any("delayed" in item.lower() for item in rec)
|
|
||||||
|
|
||||||
|
|
||||||
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
def test_home_request_calls_memory_middleware(client):
|
def test_home_request_calls_memory_middleware(client):
|
||||||
@@ -300,20 +240,21 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def enrich_context(self, user_id, message, **kwargs):
|
async def enrich_context(self, user_id, message):
|
||||||
enrich_calls.append((user_id, message))
|
enrich_calls.append((user_id, message))
|
||||||
return {"core_memory": {"tz": "UTC"}}
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
async def store_episode(self, user_id, session_id, message, response, **kwargs):
|
async def store_episode(self, user_id, session_id, message, response):
|
||||||
store_calls.append((user_id, session_id, message, response))
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
async def _mock_stream(user_id, message, context):
|
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
||||||
# Verify memory context was injected
|
# Verify memory context was injected
|
||||||
assert context.get("core_memory") == {"tz": "UTC"}
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
yield "token", "Done"
|
yield ("token", "Done")
|
||||||
|
yield ("mutations", [])
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
|
|||||||
@@ -1,82 +1,214 @@
|
|||||||
"""Tests for app.core.output_formatter.StreamFormatter."""
|
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _stream(*events: tuple[str, object]):
|
async def _stream(*events: tuple[str, object]):
|
||||||
|
"""Async generator that yields (event_type, data) tuples."""
|
||||||
for event in events:
|
for event in events:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
|
||||||
async def _collect(formatter: StreamFormatter, event_stream):
|
async def collect(formatter, event_stream):
|
||||||
frames = []
|
frames = []
|
||||||
async for frame in formatter.format(event_stream):
|
async for frame in formatter.format(event_stream):
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_text_stream() -> None:
|
async def test_home_formatter_plain_text():
|
||||||
formatter = StreamFormatter(request_id="req-1")
|
req_id = "req-1"
|
||||||
frames = await _collect(
|
events = [
|
||||||
formatter,
|
("token", "Hello world"),
|
||||||
_stream(("token", "Hello"), ("token", " world")),
|
("mutations", []),
|
||||||
)
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
assert isinstance(frames[1], WsStreamText)
|
assert frames[0].request_id == req_id
|
||||||
assert frames[1].chunk == "Hello"
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
assert isinstance(frames[2], WsStreamText)
|
assert any("Hello world" in f.chunk for f in text_frames)
|
||||||
assert frames[2].chunk == " world"
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_floating_domain_first() -> None:
|
async def test_home_formatter_entity_tags_passed_through():
|
||||||
formatter = StreamFormatter(request_id="req-2")
|
"""Entity tags are streamed as-is — the frontend parses them."""
|
||||||
frames = await _collect(
|
req_id = "req-2"
|
||||||
formatter,
|
events = [
|
||||||
_stream(
|
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
||||||
(
|
("mutations", []),
|
||||||
"floating_domain",
|
]
|
||||||
{"type": "node", "id": "n-1", "section": None},
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
),
|
frames = await collect(formatter, _stream(*events))
|
||||||
("token", "Summary"),
|
|
||||||
),
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
)
|
assert "<project>[abc-123]</project>" in text
|
||||||
|
assert "Here is your project:" in text
|
||||||
|
assert "All good." in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_multiple_tags_passed_through():
|
||||||
|
req_id = "req-3"
|
||||||
|
events = [
|
||||||
|
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert "<project>[p1]</project>" in text
|
||||||
|
assert "<task>[t1,t2]</task>" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_tool_end_ignored():
|
||||||
|
"""tool_end events are silently ignored by HomeFormatter."""
|
||||||
|
req_id = "req-4"
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||||
|
("token", "No tags here."),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert text == "No tags here."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_mutations_in_stream_end():
|
||||||
|
req_id = "req-5"
|
||||||
|
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||||
|
events = [
|
||||||
|
("token", "Done"),
|
||||||
|
("mutations", muts),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
end_frame = frames[-1]
|
||||||
|
assert isinstance(end_frame, WsStreamEnd)
|
||||||
|
assert len(end_frame.mutations) == 1
|
||||||
|
assert end_frame.mutations[0]["action"] == "insert"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_frame_order():
|
||||||
|
"""stream_start is first, stream_end is last."""
|
||||||
|
req_id = "req-6"
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||||
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_domain_from_tool_end():
|
||||||
|
req_id = "pop-1"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "ok"}),
|
||||||
|
("token", "Hello"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
assert frames[0].domain.type == "node"
|
assert frames[0].domain == "tasks"
|
||||||
assert frames[0].domain.id == "n-1"
|
assert frames[0].request_id == req_id
|
||||||
assert isinstance(frames[1], WsStreamStart)
|
|
||||||
assert isinstance(frames[2], WsStreamText)
|
|
||||||
assert frames[2].chunk == "Summary"
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_text_only():
|
||||||
|
req_id = "pop-2"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "timeline_agent", "result": "done"}),
|
||||||
|
("token", "Summary"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
|
assert frames[0].domain == "timelines"
|
||||||
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
|
assert len(text_frames) == 1
|
||||||
|
assert text_frames[0].chunk == "Summary"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_no_entity_tags():
|
||||||
|
"""FloatingFormatter never emits entity tag blocks."""
|
||||||
|
req_id = "pop-3"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "note_agent", "result": "data"}),
|
||||||
|
("token", "some text"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
# Only expected frame types
|
||||||
|
for f in frames:
|
||||||
|
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_end_frame():
|
||||||
|
req_id = "pop-4"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "project_agent", "result": "ok"}),
|
||||||
|
("token", "Done"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_ignores_unknown_events() -> None:
|
async def test_floating_formatter_default_domain_on_early_token():
|
||||||
formatter = StreamFormatter(request_id="req-3")
|
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
|
||||||
frames = await _collect(
|
req_id = "pop-5"
|
||||||
formatter,
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
_stream(("tool_end", {"name": "x"}), ("token", "ok")),
|
events = [("token", "hi"), ("mutations", [])]
|
||||||
)
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
assert frames[0].domain == "tasks"
|
||||||
assert len(text_frames) == 1
|
|
||||||
assert text_frames[0].chunk == "ok"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_empty_stream_still_brackets() -> None:
|
async def test_floating_formatter_mutations_in_stream_end():
|
||||||
formatter = StreamFormatter(request_id="req-4")
|
req_id = "pop-6"
|
||||||
frames = await _collect(formatter, _stream())
|
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
|
||||||
|
events = [
|
||||||
|
("token", "Updated"),
|
||||||
|
("mutations", muts),
|
||||||
|
]
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
assert len(frames) == 2
|
end_frame = frames[-1]
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(end_frame, WsStreamEnd)
|
||||||
assert isinstance(frames[1], WsStreamEnd)
|
assert len(end_frame.mutations) == 1
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class TestPluginRegistry:
|
|||||||
async def test_list_filter_by_query(
|
async def test_list_filter_by_query(
|
||||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
result = await reg.list_plugins(db_session, query="time")
|
result = await reg.list_plugins(db_session, query="time tracker")
|
||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
assert result.plugins[0].id == "plugin-time-tracker"
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
|
|||||||
@@ -1,98 +0,0 @@
|
|||||||
"""Tests for the preprocessor system (Step 1 — Local Agent V2).
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_preprocessors.py -v
|
|
||||||
pytest tests/test_preprocessors.py -v --preprocess-dir /path/to/folder
|
|
||||||
|
|
||||||
The folder must contain cases.yaml + data/.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.core.preprocessors import detect_content_type, preprocess
|
|
||||||
|
|
||||||
_DEFAULT_DIR = Path(__file__).parent / "fixtures" / "preprocessors"
|
|
||||||
|
|
||||||
_GENERATORS = {
|
|
||||||
"binary_noise": "some\x00\x01\x02\x03\x04\x05content" * 20,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--preprocess-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load((_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def _content(case: dict, data_dir: Path) -> str:
|
|
||||||
if "generate" in case:
|
|
||||||
return _GENERATORS[case["generate"]]
|
|
||||||
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
# ── parametrize at collection time via pytest hook ────────────────────
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "preprocess_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
test_name = metafunc.function.__name__
|
|
||||||
if test_name == "test_detect":
|
|
||||||
subset = [c for c in cases if "detect" in c]
|
|
||||||
else:
|
|
||||||
subset = [c for c in cases if "process" in c]
|
|
||||||
metafunc.parametrize("preprocess_case", subset, ids=[c["id"] for c in subset])
|
|
||||||
|
|
||||||
|
|
||||||
# ── detect ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_detect(preprocess_case, pytestconfig) -> None:
|
|
||||||
case = preprocess_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
raw = _content(case, data_dir)
|
|
||||||
filename = case.get("file", "")
|
|
||||||
ct = detect_content_type(filename, raw)
|
|
||||||
expected = case["detect"]
|
|
||||||
assert ct == expected, f"[{case['id']}] expected {expected!r}, got {ct!r}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── preprocess ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_preprocess(preprocess_case, pytestconfig) -> None:
|
|
||||||
case = preprocess_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
raw = _content(case, data_dir)
|
|
||||||
result = preprocess(case["process"], raw)
|
|
||||||
|
|
||||||
if case.get("no_html"):
|
|
||||||
assert not re.search(r"<[^>]+>", result.clean_text), "clean_text contains HTML tags"
|
|
||||||
|
|
||||||
if "min_chars" in case:
|
|
||||||
assert len(result.clean_text) >= case["min_chars"], \
|
|
||||||
f"clean_text too short: {len(result.clean_text)} < {case['min_chars']}"
|
|
||||||
|
|
||||||
if "ratio_lt" in case:
|
|
||||||
ratio = len(result.clean_text) / len(raw)
|
|
||||||
assert ratio < case["ratio_lt"], f"compression ratio {ratio:.2f} >= {case['ratio_lt']}"
|
|
||||||
|
|
||||||
for key in case.get("has_meta", []):
|
|
||||||
assert result.metadata.get(key), f"metadata missing {key!r} (got {result.metadata})"
|
|
||||||
|
|
||||||
for item in ([case["contains"]] if isinstance(case.get("contains"), str) else case.get("contains", [])):
|
|
||||||
assert item in result.clean_text, f"clean_text missing {item!r}"
|
|
||||||
|
|
||||||
for item in ([case["excludes"]] if isinstance(case.get("excludes"), str) else case.get("excludes", [])):
|
|
||||||
assert item not in result.clean_text, f"clean_text contains forbidden {item!r}"
|
|
||||||
|
|
||||||
if "content_type" in case:
|
|
||||||
assert result.content_type == case["content_type"], \
|
|
||||||
f"expected content_type {case['content_type']!r}, got {result.content_type!r}"
|
|
||||||
@@ -4,7 +4,6 @@ import pytest
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
WsDomain,
|
|
||||||
WsFrameType,
|
WsFrameType,
|
||||||
WsHomeRequest,
|
WsHomeRequest,
|
||||||
WsFloatingDomain,
|
WsFloatingDomain,
|
||||||
@@ -179,15 +178,23 @@ def test_stream_text_deserializes():
|
|||||||
def test_stream_end_defaults():
|
def test_stream_end_defaults():
|
||||||
frame = WsStreamEnd(request_id="r1")
|
frame = WsStreamEnd(request_id="r1")
|
||||||
assert frame.type == WsFrameType.stream_end
|
assert frame.type == WsFrameType.stream_end
|
||||||
|
assert frame.mutations == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_end_with_mutations():
|
||||||
|
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
|
||||||
|
frame = WsStreamEnd(request_id="r1", mutations=mutations)
|
||||||
|
assert len(frame.mutations) == 1
|
||||||
|
assert frame.mutations[0]["action"] == "create"
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_serializes():
|
def test_stream_end_serializes():
|
||||||
data = WsStreamEnd(request_id="r2").model_dump()
|
data = WsStreamEnd(request_id="r2").model_dump()
|
||||||
assert data == {"type": "stream_end", "request_id": "r2"}
|
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_deserializes():
|
def test_stream_end_deserializes():
|
||||||
raw = {"type": "stream_end", "request_id": "r3"}
|
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
|
||||||
frame = WsStreamEnd.model_validate(raw)
|
frame = WsStreamEnd.model_validate(raw)
|
||||||
assert frame.request_id == "r3"
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
@@ -196,47 +203,28 @@ def test_stream_end_deserializes():
|
|||||||
|
|
||||||
|
|
||||||
def test_floating_domain_tasks():
|
def test_floating_domain_tasks():
|
||||||
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
frame = WsFloatingDomain(request_id="r1", domain="tasks")
|
||||||
assert frame.type == WsFrameType.floating_domain
|
assert frame.type == WsFrameType.floating_domain
|
||||||
assert frame.domain.type == "task"
|
assert frame.domain == "tasks"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_valid_domains():
|
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
|
||||||
frame = WsFloatingDomain(
|
def test_floating_domain_valid_domains(domain: str):
|
||||||
request_id="r1",
|
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
|
||||||
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
assert frame.domain == domain
|
||||||
)
|
|
||||||
assert frame.domain.type == "project"
|
|
||||||
assert frame.domain.id == "213213-312321-312312-421321"
|
|
||||||
assert frame.domain.section == "task"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_object_valid():
|
def test_floating_domain_invalid():
|
||||||
frame = WsFloatingDomain(
|
with pytest.raises(ValidationError):
|
||||||
request_id="r1",
|
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
|
||||||
domain=WsDomain(type="project", id="p1", section="task"),
|
|
||||||
)
|
|
||||||
assert frame.domain.type == "project"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_serializes():
|
def test_floating_domain_serializes():
|
||||||
d = WsFloatingDomain(
|
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
|
||||||
request_id="r1",
|
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
|
||||||
domain=WsDomain(type="timeline"),
|
|
||||||
).model_dump()
|
|
||||||
assert d == {
|
|
||||||
"type": "floating_domain",
|
|
||||||
"request_id": "r1",
|
|
||||||
"domain": {"type": "timeline", "id": None, "section": None},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_deserializes():
|
def test_floating_domain_deserializes():
|
||||||
raw = {
|
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
|
||||||
"type": "floating_domain",
|
|
||||||
"request_id": "r1",
|
|
||||||
"domain": {"type": "node", "id": "n-1", "section": None},
|
|
||||||
}
|
|
||||||
frame = WsFloatingDomain.model_validate(raw)
|
frame = WsFloatingDomain.model_validate(raw)
|
||||||
assert frame.domain.type == "node"
|
assert frame.domain == "projects"
|
||||||
assert frame.domain.id == "n-1"
|
|
||||||
|
|||||||
@@ -45,13 +45,15 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
async def _mock_home_stream(user_id, message, context):
|
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
||||||
yield "token", "Hello"
|
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
async def _mock_floating_stream(user_id, message, context):
|
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
|
||||||
yield "floating_domain", {"type": "task", "id": None, "section": None}
|
yield "tool_end", {"name": "task_agent", "result": "ok"}
|
||||||
yield "token", "Here is a summary"
|
yield "token", "Here is a summary"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
@@ -102,7 +104,7 @@ def test_floating_request_produces_domain_frame(client):
|
|||||||
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
||||||
assert domain_frame["domain"]["type"] == "task"
|
assert domain_frame["domain"] == "tasks"
|
||||||
assert domain_frame["request_id"] == "p1"
|
assert domain_frame["request_id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
@@ -111,8 +113,9 @@ def test_home_request_request_id_propagated(client):
|
|||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
req_id = "my-unique-req-id"
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
async def _stream(user_id, message, context):
|
async def _stream(user_id, message, context, db_session_factory=None):
|
||||||
yield "token", "ok"
|
yield "token", "ok"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
|||||||
Reference in New Issue
Block a user