Compare commits
18 Commits
feature/de
...
297e20ce8d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
297e20ce8d | ||
|
|
5a03bd1cfb | ||
|
|
87b7a1c6c9 | ||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| 02a9684cd6 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 |
@@ -0,0 +1,92 @@
|
|||||||
|
"""Deprecate backend agent config tables.
|
||||||
|
|
||||||
|
The Electron client is now the source of truth for agent configuration
|
||||||
|
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||||
|
billing checks and trigger/run logs only.
|
||||||
|
|
||||||
|
Revision ID: 9a1f2d0b6c7e
|
||||||
|
Revises: 818478c251dc
|
||||||
|
Create Date: 2026-03-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "9a1f2d0b6c7e"
|
||||||
|
down_revision: Union[str, None] = "818478c251dc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
if "cloud_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
|
||||||
|
if "local_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||||
|
|
||||||
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
85
app/agents/filesystem_agent.py
Normal file
85
app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||||
|
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||||
|
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Note agent — tool definitions for Markdown note CRUD."""
|
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -9,14 +10,38 @@ from langchain_core.tools import tool
|
|||||||
from app.core.llm import embed
|
from app.core.llm import embed
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
NOTE_SYSTEM_PROMPT = (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - project_id is optional; link a note to a project when mentioned\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing content\n"
|
||||||
|
" before appending or replacing sections\n"
|
||||||
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
|
" when the user is working within a specific project\n"
|
||||||
|
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||||
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
|
" is already in the note (retrieved via get_note)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -105,4 +130,10 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
NOTE_TOOLS: list[Any] = [
|
||||||
|
list_notes,
|
||||||
|
get_note,
|
||||||
|
create_note,
|
||||||
|
update_note,
|
||||||
|
delete_note,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Project agent — tool definitions for project lifecycle CRUD."""
|
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -8,6 +8,22 @@ from langchain_core.tools import tool
|
|||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
PROJECT_SYSTEM_PROMPT = (
|
||||||
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
|
"update, and archive projects in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||||
|
" derive it from context data — do not fabricate content\n"
|
||||||
|
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||||
|
" user wants a complete cross-client view including archived projects\n"
|
||||||
|
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||||
|
" list_projects if you only have a project name\n"
|
||||||
|
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||||
|
" only call delete_project when the user explicitly confirms deletion."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
@@ -117,4 +133,11 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
|
PROJECT_TOOLS: list[Any] = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,14 +1,40 @@
|
|||||||
"""Task agent — tool definitions for task and task comment CRUD."""
|
"""Task agent — full CRUD for tasks and task comments."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TASK_SYSTEM_PROMPT = (
|
||||||
|
"You are a task management assistant for a project workspace.\n"
|
||||||
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||||
|
" - project_id is optional; link to a project when the user mentions one\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||||
|
" did not explicitly request; 0 otherwise\n"
|
||||||
|
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||||
|
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -22,11 +48,12 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": normalized_project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -188,8 +215,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result.get("row", {})
|
||||||
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
row_author = row.get("author", author)
|
||||||
|
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||||
|
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||||
|
row_comment_id = row.get("id", "unknown")
|
||||||
|
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -199,4 +230,16 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TOOLS: list[Any] = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,21 +1,45 @@
|
|||||||
"""Timeline agent — tool definitions for project milestone CRUD."""
|
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
_UUID_RE = re.compile(
|
||||||
|
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_uuid(value: str) -> bool:
|
||||||
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
|
TIMELINE_SYSTEM_PROMPT = (
|
||||||
|
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||||
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||||
|
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
||||||
|
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Listing without a project_id returns all timelines across projects\n"
|
||||||
|
" - Always echo the title and formatted date in your confirmation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": project_id or None},
|
filters={"projectId": normalized_project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -89,4 +113,9 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
TIMELINE_TOOLS: list[Any] = [
|
||||||
|
list_timelines,
|
||||||
|
create_timeline,
|
||||||
|
update_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,54 +1,40 @@
|
|||||||
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
Endpoints:
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
POST /agents/journey/start — start a new journey session
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
POST /agents/journey/message — continue the conversation
|
frames to the functions exported here.
|
||||||
|
|
||||||
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
|
||||||
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
||||||
2. Server creates a session, calls the LLM with a contextual system prompt,
|
data_types, schedule).
|
||||||
and returns the first question.
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
3. Client sends follow-up messages to ``/message``.
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
and sends back a ``journey_reply`` with the first question.
|
||||||
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
5. Server parses the block, sets ``done=True``, and returns the template.
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
|
via tools), and sends back a ``journey_reply``.
|
||||||
The ``prompt_template`` from the final response is meant to be stored in
|
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
||||||
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
by the Electron client (via the agent CRUD endpoints).
|
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
||||||
|
and the template. FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.db import get_session
|
|
||||||
from app.models import CloudAgentConfig, LocalAgentConfig
|
|
||||||
from app.schemas import (
|
|
||||||
JourneyMessageRequest,
|
|
||||||
JourneyResponse,
|
|
||||||
JourneyStartRequest,
|
|
||||||
UserProfile,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
@@ -59,16 +45,21 @@ _TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
|||||||
|
|
||||||
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
_MAX_TURNS: int = 5
|
_MAX_TURNS: int = 5
|
||||||
|
# Max tool-calling steps per LLM invocation.
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _JourneySession:
|
class JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -76,67 +67,77 @@ class _JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, _JourneySession] = {}
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
return None
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_LOCAL_PREAMBLE = """\
|
|
||||||
What kind of files are in the directories you want to monitor? \
|
|
||||||
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
|
||||||
|
|
||||||
_CLOUD_PREAMBLE = """\
|
|
||||||
What kind of emails or messages should I look for? \
|
|
||||||
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
local directory and produce a detailed prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
The extraction agent already has this base behaviour built in:
|
||||||
1. The type and format of the source content.
|
- Reads each file using file-system tools.
|
||||||
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||||
3. How fields should be mapped (e.g. email subject → task title).
|
- Sets isAiSuggested=1 and isApproved=0 on every record.
|
||||||
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
- Only extracts data explicitly present in the files — it never invents information.
|
||||||
5. Any special handling, date extraction, or exclusions.
|
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||||
|
what to look for and how to map it — not on the general extraction mechanics.
|
||||||
|
|
||||||
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
You have access to file-system tools to explore the user's directory:
|
||||||
these exact markers on their own lines:
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content (confirmed by your exploration).
|
||||||
|
2. How fields should be mapped (e.g. filename → task title).
|
||||||
|
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
4. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
After 3-5 questions (when you have enough information), output the final prompt_template
|
||||||
|
between these exact markers on their own lines:
|
||||||
|
|
||||||
{template_start}
|
{template_start}
|
||||||
<the complete extraction prompt here>
|
<the complete extraction prompt here>
|
||||||
{template_end}
|
{template_end}
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||||
and must return a JSON array of records in this shape:
|
and must perform CRUD operations using tools to create records. It should specify:
|
||||||
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
- What entity types to create (tasks, notes, timelines, projects).
|
||||||
|
- How to map file content to record fields (camelCase: title, status, priority,
|
||||||
|
dueDate, projectId, content, etc.).
|
||||||
|
- That isAiSuggested must be set to 1 and isApproved to 0 on every record.
|
||||||
|
- Concrete examples of mappings based on what you discovered in the directory.
|
||||||
|
|
||||||
Rules for the generated template:
|
|
||||||
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
|
||||||
- Include concrete examples of mappings.
|
|
||||||
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
|
||||||
- Set isAiSuggested: true and isApproved: false on every record.
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
Do not ask more than {max_turns} questions total. Begin by exploring the directory,
|
||||||
|
then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
def _build_system_prompt(
|
||||||
source_description = (
|
directory: str,
|
||||||
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
data_types: list[str],
|
||||||
)
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
@@ -144,7 +145,8 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
|||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
source_description=source_description,
|
directory=directory,
|
||||||
|
data_types=", ".join(data_types),
|
||||||
template_start=_TEMPLATE_START,
|
template_start=_TEMPLATE_START,
|
||||||
template_end=_TEMPLATE_END,
|
template_end=_TEMPLATE_END,
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
@@ -152,10 +154,6 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _first_question(agent_type: str) -> str:
|
|
||||||
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -168,11 +166,37 @@ def _extract_template(text: str) -> str | None:
|
|||||||
return text[start_idx:end_idx].strip() or None
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call ─────────────────────────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
def _as_text(content: Any) -> str:
|
||||||
"""Build LangChain messages from history and invoke the LLM."""
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -181,137 +205,194 @@ async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
|||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
response = await llm.ainvoke(messages)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
return response.content # type: ignore[return-value]
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_setup: journey tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
# ── Existing-config loader ────────────────────────────────────────────────
|
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _load_existing_template(
|
async def handle_journey_start(
|
||||||
agent_id: str,
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
db: AsyncSession,
|
frame: dict[str, Any],
|
||||||
) -> str | None:
|
) -> dict[str, Any]:
|
||||||
"""Return the prompt_template of an existing agent config, or None."""
|
"""Handle a ``journey_start`` WS frame.
|
||||||
# Try local first, then cloud.
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local = local_result.scalar_one_or_none()
|
|
||||||
if local is not None:
|
|
||||||
return local.prompt_template
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
select(CloudAgentConfig).where(
|
and returns the ``journey_reply`` payload.
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud = cloud_result.scalar_one_or_none()
|
|
||||||
return cloud.prompt_template if cloud is not None else None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def start_journey(
|
|
||||||
body: JourneyStartRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Start a new Chatbot Journey session.
|
|
||||||
|
|
||||||
If ``agent_id`` is provided the session is pre-seeded with the existing
|
|
||||||
agent's ``prompt_template`` so the user can refine it.
|
|
||||||
"""
|
"""
|
||||||
# Load existing template (may be None).
|
agent_type = frame.get("agent_type", "local")
|
||||||
existing_template: str | None = None
|
directory = frame.get("directory", "")
|
||||||
if body.agent_id:
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
existing_template = frame.get("existing_template")
|
||||||
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
|
||||||
# the user may be starting a fresh journey for a not-yet-persisted config).
|
|
||||||
|
|
||||||
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
first_question = _first_question(body.agent_type)
|
# listener key; fall back to a generated one if absent.
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
session_id = str(uuid.uuid4())
|
session = JourneySession(
|
||||||
session = _JourneySession(
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=current_user.id,
|
user_id=user_id,
|
||||||
agent_type=body.agent_type,
|
agent_type=agent_type,
|
||||||
# Seed history with the AI's first question so it stays consistent.
|
directory=directory,
|
||||||
history=[{"role": "assistant", "content": first_question}],
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
)
|
)
|
||||||
# Store the system prompt inside the session for reuse in /message.
|
|
||||||
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
||||||
|
# ws_context executor (already set by the WS handler before calling us).
|
||||||
|
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
||||||
|
# require at least one user/input message to be present.
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
logger.info(
|
||||||
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
||||||
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
|
||||||
async def send_journey_message(
|
|
||||||
body: JourneyMessageRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> JourneyResponse:
|
|
||||||
"""Send a message in an existing Chatbot Journey session.
|
|
||||||
|
|
||||||
The server appends the user's message to the conversation history,
|
|
||||||
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
|
||||||
``prompt_template`` block the response includes ``done=True`` and the
|
|
||||||
extracted template.
|
|
||||||
"""
|
|
||||||
session = _get_session(body.session_id, current_user.id)
|
|
||||||
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
|
||||||
|
|
||||||
# Append user turn to history.
|
|
||||||
session.history.append({"role": "user", "content": body.message})
|
|
||||||
|
|
||||||
# Call the LLM with the full conversation so far.
|
|
||||||
ai_reply = await _call_llm(system_prompt, session.history)
|
|
||||||
|
|
||||||
# Append AI turn.
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = prompt_template is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# Strip the sentinel markers from the message shown to the user.
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
if done:
|
return {
|
||||||
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
"type": "journey_reply",
|
||||||
# Clean up the session immediately on completion.
|
"session_id": session_id,
|
||||||
_sessions.pop(body.session_id, None)
|
"message": display_message,
|
||||||
else:
|
"done": done,
|
||||||
# Nudge the LLM to wrap up after max turns.
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` WS frame.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Append user turn.
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
# Call the LLM with tools.
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final template.
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
# If the LLM didn't produce a template but we've hit max turns, nudge it
|
||||||
|
# and call the LLM one more time to force template generation.
|
||||||
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
# Add a system-level nudge as a hidden user message.
|
nudge_content = (
|
||||||
session.history.append({
|
|
||||||
"role": "user",
|
|
||||||
"content": (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
),
|
|
||||||
})
|
|
||||||
|
|
||||||
return JourneyResponse(
|
|
||||||
session_id=body.session_id,
|
|
||||||
message=display_message,
|
|
||||||
done=done,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
)
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,45 +1,36 @@
|
|||||||
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
"""Agent routes.
|
||||||
|
|
||||||
Endpoints:
|
Backend responsibilities are intentionally minimal:
|
||||||
GET /agents/catalog — hardcoded agent type catalog
|
GET /agents/catalog — static catalog for UI display
|
||||||
GET /agents/local — list user's local agent configs
|
POST /agents/can-create — billing eligibility check
|
||||||
POST /agents/local — create local agent (tier-gated)
|
POST /agents/trigger — trigger a local agent run
|
||||||
PUT /agents/local/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
Agent configuration is owned by the Electron app and is not persisted
|
||||||
GET /agents/cloud — list user's cloud agent configs
|
in backend agent-config tables.
|
||||||
POST /agents/cloud — create cloud agent (tier-gated)
|
|
||||||
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
|
||||||
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
|
||||||
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
|
||||||
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy import func, or_, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
from app.core.agent_runner import run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
|
AgentCreationCheckRequest,
|
||||||
|
AgentCreationCheckResponse,
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
CloudAgentConfigCreate,
|
AgentTriggerRequest,
|
||||||
CloudAgentConfigResponse,
|
|
||||||
CloudAgentConfigUpdate,
|
|
||||||
LocalAgentConfigCreate,
|
|
||||||
LocalAgentConfigResponse,
|
|
||||||
LocalAgentConfigUpdate,
|
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,39 +47,21 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
# ── Model → schema converters ─────────────────────────────────────────
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
"task": "tasks", "tasks": "tasks",
|
||||||
return LocalAgentConfigResponse(
|
"note": "notes", "notes": "notes",
|
||||||
id=a.id,
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
name=a.name,
|
"project": "projects", "projects": "projects",
|
||||||
device_id=a.device_id,
|
}
|
||||||
directory_paths=a.directory_paths,
|
seen: set[str] = set()
|
||||||
data_types=a.data_types,
|
result: list[str] = []
|
||||||
prompt_template=a.prompt_template,
|
for v in values:
|
||||||
file_extensions=a.file_extensions,
|
mapped = normalize.get(v)
|
||||||
schedule_cron=a.schedule_cron,
|
if mapped and mapped not in seen:
|
||||||
enabled=a.enabled,
|
seen.add(mapped)
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
result.append(mapped)
|
||||||
created_at=_dt_ms(a.created_at),
|
return result
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
|
||||||
return CloudAgentConfigResponse(
|
|
||||||
id=a.id,
|
|
||||||
provider=a.provider, # type: ignore[arg-type]
|
|
||||||
name=a.name,
|
|
||||||
data_types=a.data_types,
|
|
||||||
prompt_template=a.prompt_template,
|
|
||||||
schedule_cron=a.schedule_cron,
|
|
||||||
filter_config=a.filter_config,
|
|
||||||
enabled=a.enabled,
|
|
||||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
|
||||||
created_at=_dt_ms(a.created_at),
|
|
||||||
updated_at=_dt_ms(a.updated_at),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -105,77 +78,42 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Ownership-checked lookups ─────────────────────────────────────────
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
|
||||||
async def _get_local_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_cloud_agent_for_user(
|
|
||||||
agent_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> CloudAgentConfig:
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier limit helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return combined enabled local + cloud agent count for the user."""
|
|
||||||
local_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(LocalAgentConfig.id)).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
cloud_count = (
|
|
||||||
await db.execute(
|
|
||||||
select(func.count(CloudAgentConfig.id)).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one()
|
|
||||||
return local_count + cloud_count
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
async def _enforce_run_frequency(
|
||||||
|
tier: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return # unlimited
|
||||||
|
|
||||||
class _RunsPage(BaseModel):
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
total: int
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
page: int
|
)
|
||||||
limit: int
|
result = await db.execute(
|
||||||
items: list[AgentRunLogResponse]
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -209,229 +147,52 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent CRUD ──────────────────────────────────────────────────
|
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||||
|
async def can_create_agent(
|
||||||
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
body: AgentCreationCheckRequest,
|
||||||
async def list_local_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
) -> AgentCreationCheckResponse:
|
||||||
) -> list[LocalAgentConfigResponse]:
|
"""Check if the user can create one more agent based on billing tier.
|
||||||
"""List all local directory agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_local_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
Since configuration is client-owned, the Electron app sends its current
|
||||||
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
active agent count and the backend applies tier limits.
|
||||||
async def create_local_agent(
|
|
||||||
body: LocalAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Create a new local directory agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
"""
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||||
agent = LocalAgentConfig(
|
allowed = limit == -1 or body.active_agents < limit
|
||||||
user_id=current_user.id,
|
return AgentCreationCheckResponse(
|
||||||
name=body.name,
|
allowed=allowed,
|
||||||
device_id=body.device_id,
|
tier=current_user.tier,
|
||||||
directory_paths=body.directory_paths,
|
active_agents=body.active_agents,
|
||||||
data_types=body.data_types,
|
limit=limit,
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
file_extensions=body.file_extensions,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
)
|
)
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def update_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: LocalAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> LocalAgentConfigResponse:
|
|
||||||
"""Partially update a local agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_local_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/local/{agent_id}", response_model=dict)
|
|
||||||
async def delete_local_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
|
||||||
async def list_cloud_agents(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[CloudAgentConfigResponse]:
|
|
||||||
"""List all cloud connector agent configs owned by the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
return [_to_cloud_response(a) for a in result.scalars().all()]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_cloud_agent(
|
|
||||||
body: CloudAgentConfigCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Create a new cloud connector agent config.
|
|
||||||
|
|
||||||
The combined count of enabled local and cloud agents for the user is
|
|
||||||
checked against the ``batch_active`` limit for their billing tier.
|
|
||||||
"""
|
|
||||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
|
||||||
agent = CloudAgentConfig(
|
|
||||||
user_id=current_user.id,
|
|
||||||
provider=body.provider,
|
|
||||||
name=body.name,
|
|
||||||
data_types=body.data_types,
|
|
||||||
prompt_template=body.prompt_template,
|
|
||||||
oauth_token_encrypted=body.oauth_token_encrypted,
|
|
||||||
schedule_cron=body.schedule_cron,
|
|
||||||
filter_config=body.filter_config,
|
|
||||||
)
|
|
||||||
db.add(agent)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
|
||||||
async def update_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
body: CloudAgentConfigUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> CloudAgentConfigResponse:
|
|
||||||
"""Partially update a cloud agent config. Only provided fields are changed."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
for field, value in body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(agent, field, value)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(agent)
|
|
||||||
return _to_cloud_response(agent)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/cloud/{agent_id}", response_model=dict)
|
|
||||||
async def delete_cloud_agent(
|
|
||||||
agent_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
|
||||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
|
||||||
await db.delete(agent)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Run logs ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/runs", response_model=_RunsPage)
|
|
||||||
async def list_run_logs(
|
|
||||||
agent_id: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=20, ge=1, le=100),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _RunsPage:
|
|
||||||
"""Return paginated run logs for the authenticated user.
|
|
||||||
|
|
||||||
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
|
||||||
"""
|
|
||||||
base_filter = [AgentRunLog.user_id == current_user.id]
|
|
||||||
if agent_id:
|
|
||||||
base_filter.append(AgentRunLog.agent_id == agent_id)
|
|
||||||
|
|
||||||
total = (
|
|
||||||
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
|
||||||
).scalar_one()
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(AgentRunLog)
|
|
||||||
.where(*base_filter)
|
|
||||||
.order_by(AgentRunLog.started_at.desc())
|
|
||||||
.offset((page - 1) * limit)
|
|
||||||
.limit(limit)
|
|
||||||
)
|
|
||||||
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
|
||||||
|
|
||||||
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Manual trigger stub ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
agent_id: str,
|
body: AgentTriggerRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Manually trigger an agent run.
|
"""Trigger a local agent run using client-provided configuration."""
|
||||||
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
Looks up the agent config (local or cloud) by ID with ownership check,
|
config = LocalAgentConfig(
|
||||||
creates a run log entry with ``status="running"``, and returns it.
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
Actual dispatch to the agent runner is wired in Step 3.4 once
|
device_id=body.device_id,
|
||||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
name="Local Directory Monitor",
|
||||||
"""
|
directory_paths=[body.directory],
|
||||||
# Determine agent type by trying local first, then cloud.
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
# Keep the full config object so we can pass it to the agent runner.
|
prompt_template=body.custom_agent_prompt,
|
||||||
local_config: LocalAgentConfig | None = None
|
file_extensions=[],
|
||||||
cloud_config: CloudAgentConfig | None = None
|
schedule_cron=body.batch_interval,
|
||||||
|
enabled=True,
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.id == agent_id,
|
|
||||||
LocalAgentConfig.user_id == current_user.id,
|
|
||||||
)
|
)
|
||||||
)
|
|
||||||
local_config = local_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if local_config is not None:
|
|
||||||
agent_type = "local"
|
|
||||||
else:
|
|
||||||
cloud_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.id == agent_id,
|
|
||||||
CloudAgentConfig.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud_config = cloud_result.scalar_one_or_none()
|
|
||||||
if cloud_config is not None:
|
|
||||||
agent_type = "cloud"
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=agent_id,
|
agent_id=config.id,
|
||||||
agent_type=agent_type,
|
agent_type="local",
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -439,14 +200,8 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
# Dispatch the run as a background task — returns 202 immediately.
|
|
||||||
if agent_type == "local" and local_config is not None:
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
run_local_agent(current_user.id, config, run_log, device_manager)
|
||||||
)
|
|
||||||
elif agent_type == "cloud" and cloud_config is not None:
|
|
||||||
asyncio.create_task(
|
|
||||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -10,9 +10,7 @@ from fastapi.responses import JSONResponse
|
|||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.schemas import ChatRequest, UserProfile
|
||||||
from app.db import async_session
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
@@ -22,21 +20,10 @@ async def chat(
|
|||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the Home deep agent (non-streaming)."""
|
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||||
async with async_session() as db:
|
response = await run_home(
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
memory_context = await memory.enrich_context(current_user.id, body.message)
|
|
||||||
|
|
||||||
context = {
|
|
||||||
**body.context.model_dump(),
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
response_text = await run_home(
|
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
message=body.message,
|
message=body.message,
|
||||||
context=context,
|
context=body.context.model_dump(),
|
||||||
db_session_factory=async_session,
|
|
||||||
)
|
)
|
||||||
result = ChatResponse(response=response_text)
|
return JSONResponse(content={"response": response})
|
||||||
return JSONResponse(content=result.model_dump())
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Protocol:
|
|||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``agent_data`` → enqueued in the per-run agent data queue.
|
- ``journey_start`` → starts a guided setup journey session.
|
||||||
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
- ``journey_message`` → continues a journey conversation.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
@@ -39,12 +39,13 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.deep_agent import run_home_stream, run_floating_stream
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -147,37 +148,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_data:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
await queue.put(frame)
|
|
||||||
except RuntimeError:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data for unknown run user=%s run=%s",
|
|
||||||
user_id,
|
|
||||||
run_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_data missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.agent_complete:
|
|
||||||
run_id = frame.get("run_id")
|
|
||||||
if run_id:
|
|
||||||
try:
|
|
||||||
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
|
||||||
# Sentinel: signals the agent data stream is finished.
|
|
||||||
await queue.put(None)
|
|
||||||
except RuntimeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: agent_complete missing run_id from user=%s", user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -188,6 +158,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.journey_message:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -200,35 +180,13 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
|
|
||||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
|
||||||
|
|
||||||
|
|
||||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
async def _executor(payload: dict) -> dict:
|
async def _executor(payload: dict) -> dict:
|
||||||
payload["type"] = WsFrameType.tool_call
|
payload["type"] = WsFrameType.tool_call
|
||||||
call_id = payload["id"]
|
|
||||||
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
|
||||||
await websocket.send_text(json.dumps(payload))
|
await websocket.send_text(json.dumps(payload))
|
||||||
future = device_manager.create_pending_call(user_id, call_id)
|
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||||
try:
|
return await future
|
||||||
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(
|
|
||||||
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
|
||||||
call_id, payload.get("action"), user_id,
|
|
||||||
)
|
|
||||||
# Clean up the pending future so it doesn't leak
|
|
||||||
conn = device_manager._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.pending_calls.pop(call_id, None)
|
|
||||||
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
|
||||||
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
|
||||||
call_id, type(result).__name__,
|
|
||||||
list(result.keys()) if isinstance(result, dict) else "N/A")
|
|
||||||
if result is None:
|
|
||||||
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
|
||||||
return result
|
|
||||||
return _executor
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
@@ -241,14 +199,27 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -256,12 +227,11 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(
|
event_stream = run_home_stream(user_id, message, context)
|
||||||
user_id, message, context, db_session_factory=async_session
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
)
|
|
||||||
formatter = HomeFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
# Collect text chunks to build the full response for episode storage
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -276,7 +246,14 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -290,23 +267,37 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
json.dumps(scope, ensure_ascii=True)[:200],
|
||||||
|
message[:200],
|
||||||
|
)
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {"scope": scope, **memory_context}
|
context: dict = {
|
||||||
|
"scope": scope,
|
||||||
|
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_floating_stream(
|
event_stream = run_floating_stream(user_id, message, context)
|
||||||
user_id, message, context, scope=scope,
|
formatter = StreamFormatter(request_id=request_id)
|
||||||
db_session_factory=async_session,
|
|
||||||
)
|
|
||||||
formatter = FloatingFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -323,8 +314,72 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks)
|
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||||
)
|
)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||||
|
user_id,
|
||||||
|
request_id,
|
||||||
|
session_id,
|
||||||
|
len("".join(response_chunks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_start(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": frame.get("session_id", ""),
|
||||||
|
"message": f"Failed to start journey: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a journey_message frame — continues the journey conversation."""
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
try:
|
||||||
|
reply = await handle_journey_message(user_id, frame)
|
||||||
|
await websocket.send_text(json.dumps(reply))
|
||||||
|
except Exception as exc:
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
logger.error(
|
||||||
|
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||||
|
user_id, session_id, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey error: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}))
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -360,6 +415,3 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
|
"batch_runs_per_day": 5,
|
||||||
"cloud_storage_gb": 0,
|
"cloud_storage_gb": 0,
|
||||||
"backup_gb": 0,
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
@@ -31,6 +32,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
|
"batch_runs_per_day": 50,
|
||||||
"cloud_storage_gb": 5,
|
"cloud_storage_gb": 5,
|
||||||
"backup_gb": 5,
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": 25,
|
"cloud_storage_gb": 25,
|
||||||
"backup_gb": 25,
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -51,6 +54,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
"backup_gb": -1, # unlimited
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
|
|||||||
30
app/core/agent_registry.py
Normal file
30
app/core/agent_registry.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
|
||||||
|
"""Common base for non-chat agents still using the old base contract."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str = "",
|
||||||
|
shared_memory: dict[str, Any] | None = None,
|
||||||
|
vector_store_context: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.user_id = user_id
|
||||||
|
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||||
|
self.vector_store_context: list[str] = vector_store_context or []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self) -> str: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_description(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skills(self) -> list[str]:
|
||||||
|
return []
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
"""Agent run manager.
|
"""Agent run orchestrator.
|
||||||
|
|
||||||
Drives two agent types:
|
Drives two agent types:
|
||||||
|
|
||||||
* **Local directory agent** — sends an ``agent_run`` frame to the connected
|
* **Local directory agent** — two-phase execution that mirrors the
|
||||||
Electron device, waits for the device to stream back file contents via
|
``deep_agent.py`` tool-calling pattern. Phase 1 (Triage) explores the
|
||||||
``agent_data`` frames, then calls the LLM to extract structured items from
|
user's directory via file-system tools and groups files by project.
|
||||||
each file and pushes inserts to Electron via tool-call round-trips.
|
Phase 2 (Processing) reads full file contents and performs CRUD
|
||||||
|
operations using the standard entity tools (tasks, notes, etc.).
|
||||||
|
|
||||||
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
Teams, Outlook) and pushes extracted items to Electron. **This path is
|
Teams, Outlook) and pushes extracted items to Electron.
|
||||||
a stub** — provider integrations are implemented in Step 3.6.
|
|
||||||
|
|
||||||
Usage
|
Usage
|
||||||
-----
|
-----
|
||||||
@@ -33,11 +33,17 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from croniter import croniter
|
from croniter import croniter
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from app.agents.note_agent import NOTE_TOOLS
|
||||||
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
|
||||||
@@ -45,50 +51,108 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# ── Timeouts ───────────────────────────────────────────────────────────────
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
# Max seconds to wait for Electron to finish streaming file data.
|
# Max seconds to wait for a single tool-call round-trip (FE → BE).
|
||||||
_FILE_READ_TIMEOUT: int = 120
|
_TOOL_CALL_TIMEOUT: int = 30
|
||||||
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
|
# Max LLM reasoning steps per phase.
|
||||||
_INSERT_TIMEOUT: int = 30
|
_MAX_TRIAGE_STEPS: int = 10
|
||||||
|
_MAX_PROCESSING_STEPS: int = 12
|
||||||
|
|
||||||
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
|
|
||||||
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
{"tasks", "notes", "timelines", "projects", "taskComments"}
|
"tasks": TASK_TOOLS,
|
||||||
)
|
"projects": PROJECT_TOOLS,
|
||||||
|
"notes": NOTE_TOOLS,
|
||||||
# Field descriptions fed to the extraction LLM as concise schema references.
|
"timelines": TIMELINE_TOOLS,
|
||||||
_TABLE_SCHEMAS: dict[str, str] = {
|
|
||||||
"tasks": (
|
|
||||||
"title (str, required), description (str), "
|
|
||||||
"status (todo|in_progress|done, default todo), "
|
|
||||||
"priority (high|medium|low, default medium), "
|
|
||||||
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
|
||||||
),
|
|
||||||
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
|
||||||
"timelines": (
|
|
||||||
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
|
||||||
),
|
|
||||||
"projects": "name (str, required), clientId (str)",
|
|
||||||
"taskComments": "taskId (str, required), author (str), content (str, required)",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
_EXTRACTION_SYSTEM_PROMPT = """\
|
# ── Triage prompt ─────────────────────────────────────────────────────────
|
||||||
You are a data extraction assistant for a freelance project management tool.
|
|
||||||
Given a document, extract structured records matching the user's instructions.
|
|
||||||
|
|
||||||
Output a JSON array (no markdown fences, no explanation) of objects shaped:
|
_TRIAGE_SYSTEM_PROMPT = """\
|
||||||
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
|
You are a file triage assistant for a freelance project management tool.
|
||||||
|
Your job is to explore a local directory on the user's device, understand its
|
||||||
|
structure, and group files by project context.
|
||||||
|
|
||||||
Allowed table names and their fields:
|
You have access to these tools:
|
||||||
{table_schemas}
|
- list_directory: to map folder structure
|
||||||
|
- get_file_metadata: to check creation/modification dates
|
||||||
|
- read_file_content: to read brief snippets when needed for categorisation
|
||||||
|
- list_projects / list_all_projects / get_project: to fetch existing projects
|
||||||
|
from the user's workspace and match files to them
|
||||||
|
|
||||||
Rules:
|
Instructions:
|
||||||
- Only extract tables listed in the "data_types" instructions.
|
1. Start by calling list_directory on the configured root path.
|
||||||
- Use camelCase field names exactly as shown above.
|
2. Explore subdirectories as needed to understand the structure.
|
||||||
- Omit optional fields you cannot determine; do not invent data.
|
3. Use get_file_metadata to check modification dates. Skip files that have
|
||||||
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
|
NOT been modified since: {last_run_at}.
|
||||||
- If nothing relevant is found, return an empty JSON array: []
|
4. Call list_all_projects to get the user's existing projects.
|
||||||
- Return ONLY the JSON array.
|
5. Match files to existing projects by name, folder structure, or content hints.
|
||||||
|
6. If files don't match any existing project, group them under "standalone".
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
Target entity types to extract: {data_types}
|
||||||
|
File extensions to consider: {file_extensions}
|
||||||
|
|
||||||
|
When you have finished exploring, output ONLY a JSON object (no markdown
|
||||||
|
fences, no explanation) mapping project IDs or "standalone" to file path
|
||||||
|
arrays:
|
||||||
|
|
||||||
|
{{"<project_id>": ["<file_path>", ...], "standalone": ["<file_path>", ...]}}
|
||||||
|
|
||||||
|
Return ONLY the JSON object as your final message.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Processing prompt ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PROCESSING_BASE_PROMPT = """\
|
||||||
|
You are a data extraction and management assistant for a freelance project
|
||||||
|
management tool.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||||
|
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||||
|
Notes : list_notes, get_note, create_note, update_note
|
||||||
|
Timelines : list_timelines, create_timeline, update_timeline
|
||||||
|
Projects : list_all_projects, get_project, create_project, update_project
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
1. Read the full content of each file below using read_file_content.
|
||||||
|
2. For each piece of information found, ALWAYS try to match and update an
|
||||||
|
existing record before creating a new one.
|
||||||
|
3. ONLY act on these entity types: {data_types}.
|
||||||
|
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||||
|
5. If a file contains no relevant data for the target entity types, skip it.
|
||||||
|
|
||||||
|
Update-first rules (apply in this order):
|
||||||
|
Tasks:
|
||||||
|
- Call list_tasks to find a match by title or context.
|
||||||
|
- If found: call add_task_comment (author "Adiuva"), update_task to set
|
||||||
|
assignees, state (ToDo / In Progress / Completed), or other fields.
|
||||||
|
- If NOT found: call create_task with isAiSuggested=1, isApproved=0.
|
||||||
|
Timelines:
|
||||||
|
- Call list_timelines to find a match by title or date.
|
||||||
|
- If found: call update_timeline to edit fields or mark it complete.
|
||||||
|
- If NOT found: call create_timeline with isAiSuggested=1, isApproved=0.
|
||||||
|
Notes:
|
||||||
|
- Call list_notes to find a match by title or topic, then get_note to
|
||||||
|
read its current content.
|
||||||
|
- If found: call update_note with the merged content.
|
||||||
|
- If NOT found: call create_note with isAiSuggested=1, isApproved=0.
|
||||||
|
Projects:
|
||||||
|
- Call list_all_projects to check for a match first.
|
||||||
|
- Only call create_project if the information is clearly significant and
|
||||||
|
no existing project matches. Set isAiSuggested=1, isApproved=0.
|
||||||
|
|
||||||
|
{project_context}
|
||||||
|
|
||||||
|
Files to process:
|
||||||
|
{file_list}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
After processing all files, respond with a brief summary of what you updated
|
||||||
|
and what you created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -118,100 +182,145 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
|||||||
return False # Fail-safe: don't trigger if expression is invalid.
|
return False # Fail-safe: don't trigger if expression is invalid.
|
||||||
|
|
||||||
|
|
||||||
# ── LLM extraction ─────────────────────────────────────────────────────────
|
# ── WS executor for agent context ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _extract_items_from_content(
|
def _make_agent_executor(
|
||||||
prompt_template: str,
|
|
||||||
file_content: str,
|
|
||||||
data_types: list[str],
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Call the LLM to extract structured records from *file_content*.
|
|
||||||
|
|
||||||
Returns a validated list of ``{table: str, data: dict}`` objects.
|
|
||||||
Items referencing tables not in *data_types* are discarded.
|
|
||||||
"""
|
|
||||||
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
|
|
||||||
if not allowed:
|
|
||||||
return []
|
|
||||||
|
|
||||||
schema_text = "\n".join(
|
|
||||||
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
|
|
||||||
)
|
|
||||||
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
|
|
||||||
user_prompt = (
|
|
||||||
f"User instructions: {prompt_template}\n\n"
|
|
||||||
f"Extract these record types: {', '.join(allowed)}\n\n"
|
|
||||||
f"Document:\n{file_content[:8000]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = get_llm()
|
|
||||||
raw = ""
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
|
||||||
)
|
|
||||||
raw = str(response.content).strip()
|
|
||||||
items: list[dict] = json.loads(raw)
|
|
||||||
if not isinstance(items, list):
|
|
||||||
raise ValueError("LLM response is not a JSON array")
|
|
||||||
except json.JSONDecodeError as exc:
|
|
||||||
logger.warning(
|
|
||||||
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
|
|
||||||
exc,
|
|
||||||
raw,
|
|
||||||
)
|
|
||||||
return []
|
|
||||||
# Other exceptions (LLM API errors, network errors) propagate to the
|
|
||||||
# caller (run_local_agent) which records them per-file in the run log.
|
|
||||||
|
|
||||||
validated: list[dict[str, Any]] = []
|
|
||||||
for item in items:
|
|
||||||
table = item.get("table")
|
|
||||||
data = item.get("data")
|
|
||||||
if not isinstance(table, str) or table not in allowed:
|
|
||||||
continue
|
|
||||||
if not isinstance(data, dict) or not data:
|
|
||||||
continue
|
|
||||||
# Strip any server-generated or forbidden fields.
|
|
||||||
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
|
|
||||||
data.pop(_field, None)
|
|
||||||
validated.append({"table": table, "data": data})
|
|
||||||
return validated
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tool-call insert helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _send_insert_to_client(
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
table: str,
|
|
||||||
data: dict[str, Any],
|
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
) -> dict[str, Any]:
|
) -> Any:
|
||||||
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
|
"""Create a WS callback for ``set_client_executor()`` so that all tools
|
||||||
|
can use ``execute_on_client()`` during an agent run.
|
||||||
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
|
|
||||||
review AI-produced records before they are treated as confirmed.
|
|
||||||
|
|
||||||
Raises ``asyncio.TimeoutError`` if Electron does not respond within
|
|
||||||
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
|
|
||||||
disconnects before the frame can be sent.
|
|
||||||
"""
|
"""
|
||||||
call_id = str(uuid.uuid4())
|
async def _executor(payload: dict) -> dict:
|
||||||
payload: dict[str, Any] = {
|
payload["type"] = "tool_call"
|
||||||
"type": "tool_call",
|
call_id = payload["id"]
|
||||||
"id": call_id,
|
|
||||||
"action": "insert",
|
|
||||||
"table": table,
|
|
||||||
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
|
|
||||||
}
|
|
||||||
fut = device_mgr.create_pending_call(user_id, call_id)
|
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||||
await device_mgr.send_frame(user_id, payload)
|
await device_mgr.send_frame(user_id, payload)
|
||||||
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
|
return await asyncio.wait_for(fut, timeout=_TOOL_CALL_TIMEOUT)
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent runner ──────────────────────────────────────────────────────
|
# ── LLM tool-calling loop (mirrors deep_agent._run_single_agent) ──────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_with_tools(
|
||||||
|
*,
|
||||||
|
system_prompt: str,
|
||||||
|
user_message: str,
|
||||||
|
tools: list[Any],
|
||||||
|
max_steps: int,
|
||||||
|
) -> str:
|
||||||
|
"""Run an LLM agent with tool-calling, returning the final text response.
|
||||||
|
|
||||||
|
Follows the same pattern as ``deep_agent._run_single_agent``:
|
||||||
|
bind tools → invoke → handle tool calls → repeat until final text.
|
||||||
|
"""
|
||||||
|
llm = get_llm()
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_message),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_calls_count = 0
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_calls_count += 1
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:1200],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max steps, get final response without tools.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Triage map parser ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_triage_map(raw: str) -> dict[str, list[str]] | None:
|
||||||
|
"""Extract the JSON triage map from the LLM's final response."""
|
||||||
|
text = raw.strip()
|
||||||
|
# Try direct parse first.
|
||||||
|
try:
|
||||||
|
parsed = json.loads(text)
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return {k: v for k, v in parsed.items() if isinstance(v, list)}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Try extracting JSON from markdown fences or surrounding text.
|
||||||
|
import re
|
||||||
|
match = re.search(r"\{[\s\S]*\}", text)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(match.group(0))
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return {k: v for k, v in parsed.items() if isinstance(v, list)}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||||
|
"""Build the tool list for Phase 2 based on user's data_types selection."""
|
||||||
|
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||||
|
for dt in data_types:
|
||||||
|
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||||
|
if dt_tools:
|
||||||
|
tools.extend(dt_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner (two-phase) ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def run_local_agent(
|
async def run_local_agent(
|
||||||
@@ -220,143 +329,161 @@ async def run_local_agent(
|
|||||||
run_log: AgentRunLog,
|
run_log: AgentRunLog,
|
||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a local directory agent run end-to-end.
|
"""Execute a local directory agent run using two-phase LLM-with-tools.
|
||||||
|
|
||||||
Steps:
|
Phase 1 — Triage:
|
||||||
|
Explore the directory structure, check metadata, match files to
|
||||||
|
existing projects. Output: a JSON map of project → file paths.
|
||||||
|
|
||||||
1. Verify the device identified by ``config.device_id`` is currently online.
|
Phase 2 — Processing:
|
||||||
2. Pre-create the agent_data queue so no incoming frames are lost.
|
For each project group, read full file contents and perform CRUD
|
||||||
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
|
operations using the standard entity tools.
|
||||||
4. Consume ``agent_data`` frames until the ``None`` sentinel from
|
|
||||||
``agent_complete``.
|
|
||||||
5. For each received file call the LLM to extract ``{table, data}`` items.
|
|
||||||
6. Push each item to Electron as an ``insert`` tool-call; include
|
|
||||||
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
|
|
||||||
7. Persist the run outcome (status, counts, errors) and update
|
|
||||||
``config.last_run_at``.
|
|
||||||
"""
|
"""
|
||||||
run_id = run_log.id
|
run_id = run_log.id
|
||||||
|
|
||||||
# ── 1. Device online check ─────────────────────────────────────────
|
# ── Device online check ─────────────────────────────────────────
|
||||||
if not device_mgr.is_online(user_id, config.device_id):
|
target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
|
||||||
|
if target_device_id:
|
||||||
|
is_online = device_mgr.is_online(user_id, target_device_id)
|
||||||
|
else:
|
||||||
|
is_online = device_mgr.is_online(user_id)
|
||||||
|
|
||||||
|
if not is_online:
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: skip run=%s — device %r offline for user=%s",
|
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||||
run_id,
|
run_id,
|
||||||
config.device_id,
|
target_device_id or "<any>",
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
await _finalize_run(
|
await _finalize_run(
|
||||||
run_log,
|
run_log,
|
||||||
status="error",
|
status="error",
|
||||||
errors=[f"Device {config.device_id!r} is not connected"],
|
errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# ── 2. Pre-create agent_data queue ────────────────────────────────
|
# ── Set up WS executor for tools ────────────────────────────────
|
||||||
try:
|
executor = _make_agent_executor(user_id, device_mgr)
|
||||||
device_mgr.get_agent_data_queue(user_id, run_id)
|
set_client_executor(executor)
|
||||||
except RuntimeError:
|
|
||||||
await _finalize_run(
|
|
||||||
run_log,
|
|
||||||
status="error",
|
|
||||||
errors=["Device disconnected before agent run could start"],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── 3. Send agent_run frame ────────────────────────────────────────
|
|
||||||
frame: dict[str, Any] = {
|
|
||||||
"type": "agent_run",
|
|
||||||
"run_id": run_id,
|
|
||||||
"agent_id": config.id,
|
|
||||||
"config": {
|
|
||||||
"paths": config.directory_paths,
|
|
||||||
"file_extensions": config.file_extensions,
|
|
||||||
"prompt_template": config.prompt_template,
|
|
||||||
"data_types": config.data_types,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
await device_mgr.send_frame(user_id, frame)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
|
||||||
await _finalize_run(
|
|
||||||
run_log,
|
|
||||||
status="error",
|
|
||||||
errors=[f"Failed to send agent_run frame: {exc}"],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: sent agent_run run=%s agent=%s user=%s",
|
|
||||||
run_id,
|
|
||||||
config.id,
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── 4. Consume agent_data frames ──────────────────────────────────
|
|
||||||
files: list[dict[str, Any]] = []
|
|
||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
|
|
||||||
try:
|
|
||||||
queue = device_mgr.get_agent_data_queue(user_id, run_id)
|
|
||||||
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
|
|
||||||
while True:
|
|
||||||
remaining = deadline - asyncio.get_event_loop().time()
|
|
||||||
if remaining <= 0:
|
|
||||||
errors.append("Timed out waiting for file data from device")
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
errors.append("Timed out waiting for file data from device")
|
|
||||||
break
|
|
||||||
if frame_data is None:
|
|
||||||
# Sentinel from agent_complete — stream is done.
|
|
||||||
break
|
|
||||||
files.extend(frame_data.get("files", []))
|
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(f"Queue error reading agent data: {exc}")
|
|
||||||
|
|
||||||
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
|
||||||
items_processed = 0
|
items_processed = 0
|
||||||
items_created = 0
|
items_created = 0
|
||||||
|
|
||||||
for file_info in files:
|
|
||||||
file_path: str = file_info.get("path", "<unknown>")
|
|
||||||
content: str = file_info.get("content", "")
|
|
||||||
if not content:
|
|
||||||
continue
|
|
||||||
items_processed += 1
|
|
||||||
try:
|
try:
|
||||||
extracted = await _extract_items_from_content(
|
# ── Phase 1: Triage ─────────────────────────────────────────
|
||||||
config.prompt_template, content, config.data_types
|
logger.info("agent_runner: run=%s phase=triage start user=%s", run_id, user_id)
|
||||||
|
|
||||||
|
last_run_str = "never (process all files)"
|
||||||
|
if config.last_run_at:
|
||||||
|
last_run_str = config.last_run_at.isoformat()
|
||||||
|
|
||||||
|
custom_section = ""
|
||||||
|
if config.prompt_template:
|
||||||
|
custom_section = f"User instructions:\n{config.prompt_template}"
|
||||||
|
|
||||||
|
file_ext_str = ", ".join(config.file_extensions) if config.file_extensions else "all"
|
||||||
|
|
||||||
|
triage_prompt = _TRIAGE_SYSTEM_PROMPT.format(
|
||||||
|
last_run_at=last_run_str,
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
file_extensions=file_ext_str,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
|
||||||
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
|
directory_paths = config.directory_paths
|
||||||
|
triage_user_msg = (
|
||||||
|
f"Explore these directories and produce the triage map:\n"
|
||||||
|
f"{json.dumps(directory_paths, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
triage_tools: list[Any] = list(FILESYSTEM_TOOLS) + list(PROJECT_TOOLS)
|
||||||
|
|
||||||
|
triage_response = await _run_agent_with_tools(
|
||||||
|
system_prompt=triage_prompt,
|
||||||
|
user_message=triage_user_msg,
|
||||||
|
tools=triage_tools,
|
||||||
|
max_steps=_MAX_TRIAGE_STEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
triage_map = _parse_triage_map(triage_response)
|
||||||
|
if not triage_map:
|
||||||
|
errors.append(f"Triage phase failed to produce a valid file map: {triage_response[:500]}")
|
||||||
|
await _finalize_run(run_log, status="error", errors=errors)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s triage complete groups=%d total_files=%d",
|
||||||
|
run_id,
|
||||||
|
len(triage_map),
|
||||||
|
sum(len(files) for files in triage_map.values()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Phase 2: Processing (per group) ─────────────────────────
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
|
||||||
|
for group_key, file_paths in triage_map.items():
|
||||||
|
if not file_paths:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for item in extracted:
|
logger.info(
|
||||||
try:
|
"agent_runner: run=%s phase=processing group=%s files=%d",
|
||||||
result = await _send_insert_to_client(
|
run_id,
|
||||||
user_id, item["table"], item["data"], device_mgr
|
group_key,
|
||||||
)
|
len(file_paths),
|
||||||
if result.get("error"):
|
|
||||||
errors.append(
|
|
||||||
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build project context for the LLM.
|
||||||
|
if group_key == "standalone":
|
||||||
|
project_context = "These files are not associated with any existing project."
|
||||||
else:
|
else:
|
||||||
items_created += 1
|
project_context = f"These files belong to project ID: {group_key}. Use this project_id when creating records."
|
||||||
except asyncio.TimeoutError:
|
|
||||||
errors.append(
|
file_list_str = "\n".join(f"- {fp}" for fp in file_paths)
|
||||||
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
|
|
||||||
|
processing_prompt = _PROCESSING_BASE_PROMPT.format(
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
project_context=project_context,
|
||||||
|
file_list=file_list_str,
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
)
|
)
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
|
|
||||||
|
|
||||||
# ── 7. Finalise ────────────────────────────────────────────────────
|
items_processed += len(file_paths)
|
||||||
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
|
||||||
|
|
||||||
if errors and items_created == 0:
|
try:
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=processing_prompt,
|
||||||
|
user_message="Process the listed files now.",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s group=%s processing_result=%s",
|
||||||
|
run_id,
|
||||||
|
group_key,
|
||||||
|
result_text[:500],
|
||||||
|
)
|
||||||
|
# Count created items by scanning tool call results.
|
||||||
|
# The tools themselves handle creation; we estimate from the
|
||||||
|
# summary. A more precise count would require intercepting
|
||||||
|
# tool results, but the summary is sufficient for the run log.
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Processing error for group '{group_key}': {exc}")
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: run=%s group=%s processing failed: %s",
|
||||||
|
run_id,
|
||||||
|
group_key,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
logger.error("agent_runner: run=%s failed: %s", run_id, exc)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
final_status = "error"
|
final_status = "error"
|
||||||
elif errors:
|
elif errors:
|
||||||
final_status = "partial"
|
final_status = "partial"
|
||||||
@@ -369,16 +496,15 @@ async def run_local_agent(
|
|||||||
items_processed=items_processed,
|
items_processed=items_processed,
|
||||||
items_created=items_created,
|
items_created=items_created,
|
||||||
errors=errors,
|
errors=errors,
|
||||||
update_config_last_run=True,
|
update_config_last_run=False,
|
||||||
config_id=config.id,
|
config_id=config.id,
|
||||||
config_type="local",
|
config_type="local",
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
"agent_runner: run=%s done status=%s processed=%d errors=%d",
|
||||||
run_id,
|
run_id,
|
||||||
final_status,
|
final_status,
|
||||||
items_processed,
|
items_processed,
|
||||||
items_created,
|
|
||||||
len(errors),
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -405,8 +531,7 @@ async def run_cloud_agent(
|
|||||||
3. Instantiate the provider client (Gmail or MS Graph).
|
3. Instantiate the provider client (Gmail or MS Graph).
|
||||||
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
||||||
the first run) applying ``config.filter_config`` filters.
|
the first run) applying ``config.filter_config`` filters.
|
||||||
5. For each message/email call ``_extract_items_from_content`` with
|
5. For each message/email call the LLM to extract structured items.
|
||||||
``config.prompt_template`` to get structured ``{table, data}`` items.
|
|
||||||
6. Push each item to Electron as an ``insert`` tool-call.
|
6. Push each item to Electron as an ``insert`` tool-call.
|
||||||
7. If the provider refreshed its access token, re-encrypt and write it
|
7. If the provider refreshed its access token, re-encrypt and write it
|
||||||
back to ``config.oauth_token_encrypted``.
|
back to ``config.oauth_token_encrypted``.
|
||||||
@@ -514,37 +639,40 @@ async def run_cloud_agent(
|
|||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
# ── 5–6. Extract + insert via LLM with tools ─────────────────────
|
||||||
|
executor = _make_agent_executor(user_id, device_mgr)
|
||||||
|
set_client_executor(executor)
|
||||||
|
|
||||||
|
try:
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
custom_section = ""
|
||||||
|
if config.prompt_template:
|
||||||
|
custom_section = f"User instructions:\n{config.prompt_template}"
|
||||||
|
|
||||||
for msg in raw_messages:
|
for msg in raw_messages:
|
||||||
content_text = msg.as_text
|
content_text = msg.as_text
|
||||||
if not content_text:
|
if not content_text:
|
||||||
continue
|
continue
|
||||||
items_processed += 1
|
items_processed += 1
|
||||||
|
|
||||||
|
processing_prompt = _PROCESSING_BASE_PROMPT.format(
|
||||||
|
data_types=", ".join(config.data_types),
|
||||||
|
project_context="Determine the appropriate project from the message context.",
|
||||||
|
file_list=f"Message from {config.provider} (id: {msg.id})",
|
||||||
|
custom_prompt_section=custom_section,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extracted = await _extract_items_from_content(
|
await _run_agent_with_tools(
|
||||||
config.prompt_template, content_text, config.data_types
|
system_prompt=processing_prompt,
|
||||||
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
continue
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
for item in extracted:
|
|
||||||
try:
|
|
||||||
result = await _send_insert_to_client(
|
|
||||||
user_id, item["table"], item["data"], device_mgr
|
|
||||||
)
|
|
||||||
if result.get("error"):
|
|
||||||
errors.append(
|
|
||||||
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
items_created += 1
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
errors.append(
|
|
||||||
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
|
|
||||||
)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
|
|
||||||
|
|
||||||
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
||||||
refreshed = getattr(provider, "refreshed_credentials", None)
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
@@ -610,61 +738,12 @@ async def trigger_pending_runs(
|
|||||||
* Runs execute **sequentially** to avoid flooding the WS connection.
|
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||||
"""
|
"""
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
)
|
)
|
||||||
async with async_session() as db:
|
|
||||||
local_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(
|
|
||||||
LocalAgentConfig.user_id == user_id,
|
|
||||||
LocalAgentConfig.enabled == True, # noqa: E712
|
|
||||||
LocalAgentConfig.device_id == device_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
|
||||||
|
|
||||||
cloud_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(
|
|
||||||
CloudAgentConfig.user_id == user_id,
|
|
||||||
CloudAgentConfig.enabled == True, # noqa: E712
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
|
||||||
|
|
||||||
# Build ordered list of overdue (type, config) pairs.
|
|
||||||
pending: list[tuple[str, Any]] = []
|
|
||||||
for cfg in local_configs:
|
|
||||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
|
||||||
pending.append(("local", cfg))
|
|
||||||
for cfg in cloud_configs:
|
|
||||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
|
||||||
pending.append(("cloud", cfg))
|
|
||||||
|
|
||||||
if not pending:
|
|
||||||
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
for agent_type, cfg in pending:
|
|
||||||
# Create a fresh run log for this scheduled dispatch.
|
|
||||||
run_log = AgentRunLog(
|
|
||||||
agent_id=cfg.id,
|
|
||||||
agent_type=agent_type,
|
|
||||||
user_id=user_id,
|
|
||||||
status="running",
|
|
||||||
)
|
|
||||||
async with async_session() as db:
|
|
||||||
db.add(run_log)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(run_log)
|
|
||||||
|
|
||||||
if agent_type == "local":
|
|
||||||
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
|
||||||
else:
|
|
||||||
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helper ─────────────────────────────────────────────────────────
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -3,20 +3,15 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager participates in two interaction patterns:
|
The manager handles the **tool-call round-trip** pattern:
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||||
1. **Tool-call round-trip** (bidirectional CRUD):
|
returns ``tool_result`` frame.
|
||||||
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
|
||||||
``tool_result`` frame.
|
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
receive the result dict from Electron.
|
receive the result dict from Electron.
|
||||||
|
|
||||||
2. **Agent-data streaming** (local directory agent runs):
|
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||||
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
``execute_on_client()`` in ``ws_context.py``.
|
||||||
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
|
||||||
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
|
||||||
a specific ``run_id`` so the agent runner can iterate frames.
|
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -42,8 +37,6 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
# Per-run queues for agent_data / agent_complete frames.
|
|
||||||
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -153,31 +146,6 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
# ── Agent-data queue ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def get_agent_data_queue(
|
|
||||||
self, user_id: str, run_id: str
|
|
||||||
) -> asyncio.Queue[dict | None]:
|
|
||||||
"""Return (creating if absent) the queue for *run_id* agent frames.
|
|
||||||
|
|
||||||
The agent runner reads from this queue. The device WS handler writes
|
|
||||||
to it. ``None`` is the sentinel that signals the stream is finished.
|
|
||||||
"""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"get_agent_data_queue: user {user_id!r} is not connected"
|
|
||||||
)
|
|
||||||
if run_id not in conn.agent_data_queues:
|
|
||||||
conn.agent_data_queues[run_id] = asyncio.Queue()
|
|
||||||
return conn.agent_data_queues[run_id]
|
|
||||||
|
|
||||||
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
|
||||||
"""Remove the queue for *run_id* once a run has completed."""
|
|
||||||
conn = self._connections.get(user_id)
|
|
||||||
if conn:
|
|
||||||
conn.agent_data_queues.pop(run_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -32,6 +33,14 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
# Some provider responses include a plain dict in the `usage` field where a
|
||||||
|
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
|||||||
@@ -43,15 +43,21 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
class MemoryMiddleware:
|
||||||
"""Enrich agent context with memory and persist interactions after."""
|
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
async def enrich_context(
|
||||||
"""Build memory context dict to inject into the agent before LLM call.
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
core_memory — {key: plaintext_value, ...}
|
core_memory — {key: plaintext_value, ...}
|
||||||
@@ -65,9 +71,21 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
len(core),
|
||||||
|
len(associative),
|
||||||
|
len(episodic),
|
||||||
|
len(proactive),
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -81,6 +99,7 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -103,11 +122,19 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -133,10 +160,176 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
logger.info(
|
||||||
|
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||||
|
trace_id or "-",
|
||||||
|
user_id,
|
||||||
|
user_dbg.get("tier") or "-",
|
||||||
|
key,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||||
|
"""Return core memory as editable blocks (label/value)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore)
|
||||||
|
.where(MemoryCore.user_id == user_id)
|
||||||
|
.order_by(MemoryCore.key.asc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[dict[str, str]] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append({"label": row.key, "value": plaintext})
|
||||||
|
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||||
|
"""Return a single core memory block value by label."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||||
|
return None
|
||||||
|
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||||
|
"""Delete a core memory block by label. Returns True if deleted."""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == label,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||||
|
return False
|
||||||
|
|
||||||
|
await self._db.delete(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||||
|
"""Append content to a core block, creating it if missing."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None:
|
||||||
|
await self.update_core(user_id, label, content)
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||||
|
return
|
||||||
|
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||||
|
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||||
|
|
||||||
|
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||||
|
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||||
|
current = await self.get_core_block(user_id, label)
|
||||||
|
if current is None or old not in current:
|
||||||
|
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||||
|
return False
|
||||||
|
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||||
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
|
"""Insert a long-term archival memory entry."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=None,
|
||||||
|
entity_type=source,
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||||
|
"""Search recall memory (episodic summaries) by keyword."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(100)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
needle = query.strip().lower()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is None:
|
||||||
|
continue
|
||||||
|
if not needle or needle in plaintext.lower():
|
||||||
|
out.append(plaintext)
|
||||||
|
if len(out) >= max(top_k, 1):
|
||||||
|
break
|
||||||
|
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||||
|
return out
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -148,6 +341,16 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||||
|
"""Load lightweight user debug fields for trace logs."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
return {"tier": None}
|
||||||
|
return {
|
||||||
|
"tier": user.tier,
|
||||||
|
}
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -183,10 +386,17 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_episodic(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryEpisodic)
|
query
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,141 +1,47 @@
|
|||||||
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
"""Output formatter for deep-agent stream events."""
|
||||||
|
|
||||||
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
|
||||||
* ``("token", str)`` — supervisor text token
|
|
||||||
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
|
||||||
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
|
||||||
|
|
||||||
HomeFormatter:
|
|
||||||
* Streams text tokens as-is → emits ``WsStreamText``
|
|
||||||
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
|
||||||
for the frontend to parse and render as interactive components)
|
|
||||||
* Attaches mutations → injects into ``WsStreamEnd``
|
|
||||||
|
|
||||||
FloatingFormatter:
|
|
||||||
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
|
||||||
* Streams text tokens → emits ``WsStreamText``
|
|
||||||
* Attaches mutations → injects into ``WsStreamEnd``
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Map sub-agent tool name → floating domain / entity type
|
|
||||||
_AGENT_DOMAIN: dict[str, str] = {
|
|
||||||
"task_agent": "tasks",
|
|
||||||
"timeline_agent": "timelines",
|
|
||||||
"note_agent": "notes",
|
|
||||||
"project_agent": "projects",
|
|
||||||
}
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
class HomeFormatter:
|
class StreamFormatter:
|
||||||
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||||
|
|
||||||
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
|
||||||
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
|
||||||
is responsible for parsing those and rendering interactive components.
|
|
||||||
Mutations are attached to ``WsStreamEnd``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self._mutations: list[dict] = []
|
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
started = False
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "token":
|
if event_type == "floating_domain":
|
||||||
if data:
|
if isinstance(data, dict):
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
|
||||||
|
|
||||||
elif event_type == "mutations":
|
|
||||||
self._mutations = data or []
|
|
||||||
|
|
||||||
yield WsStreamEnd(
|
|
||||||
request_id=self.request_id,
|
|
||||||
mutations=[
|
|
||||||
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
|
||||||
for m in self._mutations
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FloatingFormatter:
|
|
||||||
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
|
||||||
|
|
||||||
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
|
||||||
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
|
||||||
``WsStreamText``. No block parsing for floating context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
|
||||||
self.request_id = request_id
|
|
||||||
self._mutations: list[dict] = []
|
|
||||||
|
|
||||||
async def format(
|
|
||||||
self,
|
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
domain_sent = False
|
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
|
||||||
if event_type == "tool_end" and not domain_sent:
|
|
||||||
# Sniff domain from the first sub-agent that completes
|
|
||||||
name = data.get("name", "")
|
|
||||||
domain = _AGENT_DOMAIN.get(name, "tasks")
|
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=domain, # type: ignore[arg-type]
|
domain=data,
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if event_type != "token":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not started:
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
started = True
|
||||||
|
|
||||||
elif event_type == "token":
|
text = str(data or "")
|
||||||
if not domain_sent:
|
if text:
|
||||||
# First token arrived before any tool_end — default domain
|
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||||
yield WsFloatingDomain(
|
|
||||||
request_id=self.request_id,
|
if not started:
|
||||||
domain="tasks", # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
domain_sent = True
|
yield WsStreamEnd(request_id=self.request_id)
|
||||||
if data:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=data)
|
|
||||||
|
|
||||||
elif event_type == "mutations":
|
|
||||||
self._mutations = data or []
|
|
||||||
|
|
||||||
# If no events triggered domain_sent (edge case), still emit structure
|
|
||||||
if not domain_sent:
|
|
||||||
yield WsFloatingDomain(
|
|
||||||
request_id=self.request_id,
|
|
||||||
domain="tasks", # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
|
|
||||||
yield WsStreamEnd(
|
|
||||||
request_id=self.request_id,
|
|
||||||
mutations=[
|
|
||||||
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
|
||||||
for m in self._mutations
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -7,21 +7,18 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the deep agent runs; cleared after.
|
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional collector that captures raw execute_on_client results.
|
# Optional collector that captures raw execute_on_client results.
|
||||||
# Set by the deep agent tool loop to capture CRUD mutations.
|
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
"_tool_result_collector", default=None
|
"_tool_result_collector", default=None
|
||||||
)
|
)
|
||||||
@@ -84,17 +81,12 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
|
||||||
result = await callback(payload)
|
result = await callback(payload)
|
||||||
if result is None:
|
|
||||||
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
|
||||||
else:
|
|
||||||
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
|
||||||
collector = _tool_result_collector.get(None)
|
collector = _tool_result_collector.get(None)
|
||||||
if collector is not None and action in ("insert", "update", "delete"):
|
if collector is not None:
|
||||||
collector.append({
|
collector.append({
|
||||||
"action": action,
|
"action": action,
|
||||||
"table": table,
|
"table": table,
|
||||||
"data": data or {},
|
"data": result,
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool
|
# Startup: ensure agent tool modules are loaded.
|
||||||
|
import app.agents # noqa: F401
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
@@ -48,7 +50,7 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
@@ -58,7 +60,6 @@ def create_app() -> FastAPI:
|
|||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(agent_setup.router, prefix="/api/v1")
|
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
|||||||
145
app/schemas.py
145
app/schemas.py
@@ -142,9 +142,6 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
agent_run = "agent_run"
|
|
||||||
agent_data = "agent_data"
|
|
||||||
agent_complete = "agent_complete"
|
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
@@ -156,6 +153,10 @@ class WsFrameType(str, Enum):
|
|||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
|
# ── v4 journey frame types ────────────────────────────────────────
|
||||||
|
journey_start = "journey_start"
|
||||||
|
journey_message = "journey_message"
|
||||||
|
journey_reply = "journey_reply"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -208,31 +209,6 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WsAgentRun(BaseModel):
|
|
||||||
"""Server → Client: trigger an agent run on the connected device."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
|
||||||
run_id: str
|
|
||||||
agent_id: str
|
|
||||||
config: dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentData(BaseModel):
|
|
||||||
"""Client → Server: files read by the local agent."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
|
||||||
run_id: str
|
|
||||||
files: list[dict[str, Any]]
|
|
||||||
|
|
||||||
|
|
||||||
class WsAgentComplete(BaseModel):
|
|
||||||
"""Client → Server: Electron signals it has finished reading files."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
|
||||||
run_id: str
|
|
||||||
files_read: int
|
|
||||||
errors: list[str] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -279,7 +255,14 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
class WsDomain(BaseModel):
|
||||||
|
"""Structured floating domain payload for UI routing decisions."""
|
||||||
|
|
||||||
|
type: Literal["task", "timeline", "project", "node"]
|
||||||
|
id: str | None = None
|
||||||
|
section: Literal["task", "timeline", "note"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -287,7 +270,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: Literal["tasks", "timelines", "notes", "projects"]
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -296,84 +279,27 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
config_schema: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local Agent Config ────────────────────────────────────────────────
|
class AgentCreationCheckRequest(BaseModel):
|
||||||
|
active_agents: int = Field(ge=0, default=0)
|
||||||
class LocalAgentConfigCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
device_id: str
|
|
||||||
directory_paths: list[str]
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
file_extensions: list[str]
|
|
||||||
schedule_cron: str
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigUpdate(BaseModel):
|
class AgentCreationCheckResponse(BaseModel):
|
||||||
name: str | None = None
|
allowed: bool
|
||||||
device_id: str | None = None
|
tier: BillingTier
|
||||||
directory_paths: list[str] | None = None
|
active_agents: int
|
||||||
data_types: list[str] | None = None
|
limit: int
|
||||||
prompt_template: str | None = None
|
|
||||||
file_extensions: list[str] | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfigResponse(BaseModel):
|
class AgentTriggerRequest(BaseModel):
|
||||||
id: str
|
directory: str = Field(min_length=1)
|
||||||
name: str
|
device_id: str = Field(default="")
|
||||||
device_id: str
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
directory_paths: list[str]
|
actions_by_type: dict[str, list[str]] | None = None
|
||||||
data_types: list[str]
|
batch_interval: str = Field(min_length=1)
|
||||||
prompt_template: str
|
custom_agent_prompt: str = Field(min_length=1)
|
||||||
file_extensions: list[str]
|
active_agents: int = Field(ge=0, default=0)
|
||||||
schedule_cron: str
|
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Agent Config ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class CloudAgentConfigCreate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
oauth_token_encrypted: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigUpdate(BaseModel):
|
|
||||||
provider: Literal["gmail", "teams", "outlook"] | None = None
|
|
||||||
name: str | None = None
|
|
||||||
data_types: list[str] | None = None
|
|
||||||
prompt_template: str | None = None
|
|
||||||
oauth_token_encrypted: str | None = None
|
|
||||||
schedule_cron: str | None = None
|
|
||||||
filter_config: dict[str, Any] | None = None
|
|
||||||
enabled: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfigResponse(BaseModel):
|
|
||||||
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
provider: Literal["gmail", "teams", "outlook"]
|
|
||||||
name: str
|
|
||||||
data_types: list[str]
|
|
||||||
prompt_template: str
|
|
||||||
schedule_cron: str
|
|
||||||
filter_config: dict[str, Any] | None
|
|
||||||
enabled: bool
|
|
||||||
last_run_at: int | None
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -392,18 +318,3 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
class JourneyStartRequest(BaseModel):
|
|
||||||
agent_type: Literal["local", "cloud"]
|
|
||||||
agent_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyMessageRequest(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
class JourneyResponse(BaseModel):
|
|
||||||
session_id: str
|
|
||||||
message: str
|
|
||||||
done: bool
|
|
||||||
prompt_template: str | None = None
|
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ gunicorn>=22.0.0
|
|||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
langchain-litellm>=0.1.0
|
langchain-litellm>=0.1.0
|
||||||
langgraph>=0.3.0
|
|
||||||
deepagents>=0.4.10
|
|
||||||
litellm>=1.50.0
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ Coverage:
|
|||||||
- run_local_agent — file-read timeout path
|
- run_local_agent — file-read timeout path
|
||||||
- run_local_agent — LLM extraction error path
|
- run_local_agent — LLM extraction error path
|
||||||
- run_cloud_agent — stub returns error immediately
|
- run_cloud_agent — stub returns error immediately
|
||||||
- trigger_pending_runs — overdue local + cloud dispatched
|
- trigger_pending_runs — skipped when config is client-owned
|
||||||
- trigger_pending_runs — non-overdue skipped
|
- trigger_pending_runs — non-overdue skipped
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
Integration:
|
Integration:
|
||||||
- POST /agents/{id}/run — 404 on unknown agent
|
- POST /agents/can-create — billing eligibility check
|
||||||
- POST /agents/{id}/run — creates run log + dispatches background task
|
- POST /agents/trigger — creates run log + dispatches background task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
|
|||||||
assert kwargs["items_processed"] == 1
|
assert kwargs["items_processed"] == 1
|
||||||
assert kwargs["items_created"] == 1
|
assert kwargs["items_created"] == 1
|
||||||
assert kwargs["errors"] == []
|
assert kwargs["errors"] == []
|
||||||
assert kwargs["update_config_last_run"] is True
|
assert kwargs["update_config_last_run"] is False
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
# Verify agent_run frame was sent.
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
@@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
"""If no agents are overdue trigger_pending_runs does nothing."""
|
"""Pending-run scan is skipped because agent config is client-owned."""
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
config = _make_local_config()
|
|
||||||
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
|
||||||
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
"""Local agents are only triggered for the matching device_id."""
|
"""Device filtering is no longer backend-managed in pending runs."""
|
||||||
# The DB query already filters by device_id, so we verify the SELECT
|
|
||||||
# includes the device_id filter by checking that a config bound to a
|
|
||||||
# different device is never dispatched.
|
|
||||||
#
|
|
||||||
# Since trigger_pending_runs queries with device_id == "dev-001",
|
|
||||||
# simulate the DB returning an empty list (as it would for a mismatch).
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
|
||||||
mock_ctx = AsyncMock()
|
|
||||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
|
||||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
mock_session_factory.return_value = mock_ctx
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
"""Overdue local agent triggers run_local_agent sequentially."""
|
"""No pending runs are dispatched by backend after config deprecation."""
|
||||||
config = _make_local_config() # last_run_at=None → always overdue
|
|
||||||
|
|
||||||
mock_db_result_local = MagicMock()
|
|
||||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
|
||||||
|
|
||||||
mock_db_result_cloud = MagicMock()
|
|
||||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
call_order: list[str] = []
|
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
|
||||||
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
|
||||||
call_order.append("run_local")
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
|
||||||
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
|
||||||
# First call: query configs. Subsequent calls: create run_log.
|
|
||||||
mock_query_ctx = AsyncMock()
|
|
||||||
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
|
||||||
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_query_ctx.execute = AsyncMock(
|
|
||||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
|
||||||
)
|
|
||||||
|
|
||||||
run_log_obj = AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=config.id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_FREE_UID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
mock_insert_ctx = AsyncMock()
|
|
||||||
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
|
||||||
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
||||||
mock_insert_ctx.add = MagicMock()
|
|
||||||
mock_insert_ctx.commit = AsyncMock()
|
|
||||||
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
|
||||||
|
|
||||||
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
assert call_order == ["run_local"]
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: POST /agents/{id}/run
|
# Integration: POST /agents/can-create and /agents/trigger
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -820,50 +742,67 @@ def _override_db(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_unknown_agent(client):
|
async def test_can_create_agent_allows_when_under_limit(client):
|
||||||
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
f"/api/v1/agents/{uuid.uuid4()}/run",
|
"/api/v1/agents/can-create",
|
||||||
headers=auth_header("power"),
|
json={"active_agents": 0},
|
||||||
|
headers=auth_header("free"),
|
||||||
)
|
)
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["allowed"] is True
|
||||||
|
assert body["tier"] == "free"
|
||||||
|
assert body["active_agents"] == 0
|
||||||
|
assert body["limit"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_can_create_agent_denies_when_at_limit(client):
|
||||||
|
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/can-create",
|
||||||
|
json={"active_agents": 2},
|
||||||
|
headers=auth_header("free"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["allowed"] is False
|
||||||
|
assert body["limit"] == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
||||||
# Create the local agent config in the DB.
|
dispatched: list[tuple[str, str]] = []
|
||||||
config = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=TEST_USER_IDS["power"],
|
|
||||||
device_id="dev-001",
|
|
||||||
name="My Agent",
|
|
||||||
directory_paths=["/home/user/docs"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
prompt_template="Extract tasks.",
|
|
||||||
file_extensions=[".txt"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
)
|
|
||||||
db_session.add(config)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
dispatched: list = []
|
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
dispatched.append((user_id, cfg.id))
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
|
def _fake_create_task(coro):
|
||||||
|
coro.close()
|
||||||
|
return MagicMock()
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
|
mock_create_task.side_effect = _fake_create_task
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
f"/api/v1/agents/{config.id}/run",
|
"/api/v1/agents/trigger",
|
||||||
|
json={
|
||||||
|
"directory": "/home/user/docs",
|
||||||
|
"what_to_extract": ["task", "note"],
|
||||||
|
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
||||||
|
"batch_interval": "0 */6 * * *",
|
||||||
|
"custom_agent_prompt": "Extract tasks and notes.",
|
||||||
|
"active_agents": 0,
|
||||||
|
},
|
||||||
headers=auth_header("power"),
|
headers=auth_header("power"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status_code == 202
|
assert resp.status_code == 202
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["agent_id"] == config.id
|
assert isinstance(data["agent_id"], str)
|
||||||
|
assert data["agent_id"]
|
||||||
assert data["status"] == "running"
|
assert data["status"] == "running"
|
||||||
assert data["agent_type"] == "local"
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
|||||||
288
tests/test_deep_agent.py
Normal file
288
tests/test_deep_agent.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.core.deep_agent import (
|
||||||
|
_infer_floating_domain,
|
||||||
|
_normalize_tagged_list_lines,
|
||||||
|
run_floating,
|
||||||
|
run_floating_stream,
|
||||||
|
run_home,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTool:
|
||||||
|
name = "list_tasks"
|
||||||
|
|
||||||
|
async def ainvoke(self, args):
|
||||||
|
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.agent_calls = 0
|
||||||
|
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, messages):
|
||||||
|
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
||||||
|
if "strict domain classifier" in system_prompt:
|
||||||
|
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
||||||
|
|
||||||
|
self.agent_calls += 1
|
||||||
|
if self.agent_calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call-1",
|
||||||
|
"name": "list_tasks",
|
||||||
|
"args": {"project_id": "proj-1"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
||||||
|
assert tool_messages, "Expected at least one tool message"
|
||||||
|
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
||||||
|
|
||||||
|
async def astream(self, _messages):
|
||||||
|
yield SimpleNamespace(content="stream-")
|
||||||
|
yield SimpleNamespace(content="ok")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_home_uses_mocked_tool_result():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
out = await run_home("user-1", "list my tasks", {})
|
||||||
|
|
||||||
|
assert "Final answer from mocked tool" in out
|
||||||
|
assert "Mock Task" in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"show me timeline updates",
|
||||||
|
{"scope": {"type": "timeline", "id": "tl-1"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[0] == (
|
||||||
|
"floating_domain",
|
||||||
|
{"type": "timeline", "id": "tl-1", "section": None},
|
||||||
|
)
|
||||||
|
assert ("token", "stream-") in events
|
||||||
|
assert ("token", "ok") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
||||||
|
class _ClassifierOnlyLLM:
|
||||||
|
async def ainvoke(self, _messages):
|
||||||
|
return AIMessage(
|
||||||
|
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
||||||
|
domain = await _infer_floating_domain(
|
||||||
|
"Quali sono i miei task per il progetto X",
|
||||||
|
{
|
||||||
|
"scope": {"type": "timeline"},
|
||||||
|
"resolved_project_id": "213213-312321-312312-421321",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert domain == {
|
||||||
|
"type": "project",
|
||||||
|
"id": "213213-312321-312312-421321",
|
||||||
|
"section": "task",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
||||||
|
raw = (
|
||||||
|
"Certo!\n\n"
|
||||||
|
"1. **Task A** — priorita high <task>[task-1]</task>\n"
|
||||||
|
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
|
||||||
|
|
||||||
|
assert "<task>[task-1]</task>" in out
|
||||||
|
assert "<task>[task-2]</task>" in out
|
||||||
|
assert "Task A" not in out
|
||||||
|
assert "Task B" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
|
||||||
|
today = date.today()
|
||||||
|
tomorrow = today + timedelta(days=1)
|
||||||
|
yesterday = today - timedelta(days=1)
|
||||||
|
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
|
||||||
|
|
||||||
|
raw = "\n".join(
|
||||||
|
[
|
||||||
|
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
|
||||||
|
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
|
||||||
|
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
|
||||||
|
|
||||||
|
assert "<timeline>[tl-next]</timeline>" in out
|
||||||
|
assert "<timeline>[tl-old]</timeline>" not in out
|
||||||
|
assert "<timeline>[tl-future]</timeline>" not in out
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
|
return (
|
||||||
|
"Hai 1 task:\\n"
|
||||||
|
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
|
):
|
||||||
|
text, _domain = await run_floating(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "<task>" not in text
|
||||||
|
assert "</task>" not in text
|
||||||
|
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
yield "token", "Hai 1 task:\\n"
|
||||||
|
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
||||||
|
combined = "".join(token_events)
|
||||||
|
assert "<task>" not in combined
|
||||||
|
assert "</task>" not in combined
|
||||||
|
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
||||||
|
class _NoChunkLLM:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def bind_tools(self, _tools):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def ainvoke(self, _messages):
|
||||||
|
self.calls += 1
|
||||||
|
if self.calls == 1:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{
|
||||||
|
"id": "call-1",
|
||||||
|
"name": "list_tasks",
|
||||||
|
"args": {},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(content="No notes found.")
|
||||||
|
|
||||||
|
async def astream(self, _messages):
|
||||||
|
if False:
|
||||||
|
yield None
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
||||||
|
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali sono le note?",
|
||||||
|
{"scope": {"type": "note"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[0][0] == "floating_domain"
|
||||||
|
assert ("token", "No notes found.") in events
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_run_single_agent(**_kwargs):
|
||||||
|
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||||
|
):
|
||||||
|
text, _domain = await run_floating(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert text == "No results found."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
||||||
|
fake_llm = _FakeLLM()
|
||||||
|
|
||||||
|
async def _fake_stream(**_kwargs):
|
||||||
|
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||||
|
|
||||||
|
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||||
|
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in run_floating_stream(
|
||||||
|
"user-1",
|
||||||
|
"quali task ho?",
|
||||||
|
{"scope": {"type": "task"}},
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert ("token", "No results found.") in events
|
||||||
@@ -110,6 +110,32 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
|
|||||||
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
|
||||||
|
target_session = str(uuid.uuid4())
|
||||||
|
other_session = str(uuid.uuid4())
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Target session memory"),
|
||||||
|
session_id=target_session,
|
||||||
|
))
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Other session memory"),
|
||||||
|
session_id=other_session,
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
|
||||||
|
|
||||||
|
episodic = ctx.get("episodic_memory", [])
|
||||||
|
assert any("Target session" in s for s in episodic)
|
||||||
|
assert not any("Other session" in s for s in episodic)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
# Add one pattern above threshold and one below
|
# Add one pattern above threshold and one below
|
||||||
@@ -229,6 +255,40 @@ async def test_update_core_upsert(db_session, user_with_key):
|
|||||||
assert _dec(rows[0].value_encrypted) == "fr"
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_core_block_edit_ops(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.update_core(USER_ID, "human", "Name: Roberto")
|
||||||
|
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
|
||||||
|
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
|
||||||
|
|
||||||
|
blocks = await middleware.list_core_blocks(USER_ID)
|
||||||
|
human = next(b for b in blocks if b["label"] == "human")
|
||||||
|
|
||||||
|
assert replaced is True
|
||||||
|
assert "Name: Robert" in human["value"]
|
||||||
|
assert "Timezone: Europe/Rome" in human["value"]
|
||||||
|
|
||||||
|
deleted = await middleware.delete_core(USER_ID, "human")
|
||||||
|
assert deleted is True
|
||||||
|
assert await middleware.get_core_block(USER_ID, "human") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
|
||||||
|
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
|
||||||
|
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
|
||||||
|
|
||||||
|
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
|
||||||
|
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
|
||||||
|
|
||||||
|
assert any("whitelist" in item.lower() for item in arch)
|
||||||
|
assert any("delayed" in item.lower() for item in rec)
|
||||||
|
|
||||||
|
|
||||||
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
def test_home_request_calls_memory_middleware(client):
|
def test_home_request_calls_memory_middleware(client):
|
||||||
@@ -240,21 +300,20 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def enrich_context(self, user_id, message):
|
async def enrich_context(self, user_id, message, **kwargs):
|
||||||
enrich_calls.append((user_id, message))
|
enrich_calls.append((user_id, message))
|
||||||
return {"core_memory": {"tz": "UTC"}}
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
async def store_episode(self, user_id, session_id, message, response):
|
async def store_episode(self, user_id, session_id, message, response, **kwargs):
|
||||||
store_calls.append((user_id, session_id, message, response))
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
async def _mock_stream(user_id, message, context):
|
||||||
# Verify memory context was injected
|
# Verify memory context was injected
|
||||||
assert context.get("core_memory") == {"tz": "UTC"}
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
yield ("token", "Done")
|
yield "token", "Done"
|
||||||
yield ("mutations", [])
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
|
|||||||
@@ -1,214 +1,82 @@
|
|||||||
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
"""Tests for app.core.output_formatter.StreamFormatter."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
from app.core.output_formatter import StreamFormatter
|
||||||
from app.schemas import (
|
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||||
WsFloatingDomain,
|
|
||||||
WsStreamEnd,
|
|
||||||
WsStreamStart,
|
|
||||||
WsStreamText,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _stream(*events: tuple[str, object]):
|
async def _stream(*events: tuple[str, object]):
|
||||||
"""Async generator that yields (event_type, data) tuples."""
|
|
||||||
for event in events:
|
for event in events:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
|
||||||
async def collect(formatter, event_stream):
|
async def _collect(formatter: StreamFormatter, event_stream):
|
||||||
frames = []
|
frames = []
|
||||||
async for frame in formatter.format(event_stream):
|
async for frame in formatter.format(event_stream):
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_plain_text():
|
async def test_stream_formatter_text_stream() -> None:
|
||||||
req_id = "req-1"
|
formatter = StreamFormatter(request_id="req-1")
|
||||||
events = [
|
frames = await _collect(
|
||||||
("token", "Hello world"),
|
formatter,
|
||||||
("mutations", []),
|
_stream(("token", "Hello"), ("token", " world")),
|
||||||
]
|
)
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
assert frames[0].request_id == req_id
|
assert isinstance(frames[1], WsStreamText)
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
assert frames[1].chunk == "Hello"
|
||||||
assert any("Hello world" in f.chunk for f in text_frames)
|
assert isinstance(frames[2], WsStreamText)
|
||||||
|
assert frames[2].chunk == " world"
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_home_formatter_entity_tags_passed_through():
|
async def test_stream_formatter_floating_domain_first() -> None:
|
||||||
"""Entity tags are streamed as-is — the frontend parses them."""
|
formatter = StreamFormatter(request_id="req-2")
|
||||||
req_id = "req-2"
|
frames = await _collect(
|
||||||
events = [
|
formatter,
|
||||||
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
_stream(
|
||||||
("mutations", []),
|
(
|
||||||
]
|
"floating_domain",
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
{"type": "node", "id": "n-1", "section": None},
|
||||||
frames = await collect(formatter, _stream(*events))
|
),
|
||||||
|
|
||||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
|
||||||
assert "<project>[abc-123]</project>" in text
|
|
||||||
assert "Here is your project:" in text
|
|
||||||
assert "All good." in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_multiple_tags_passed_through():
|
|
||||||
req_id = "req-3"
|
|
||||||
events = [
|
|
||||||
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
|
||||||
assert "<project>[p1]</project>" in text
|
|
||||||
assert "<task>[t1,t2]</task>" in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_tool_end_ignored():
|
|
||||||
"""tool_end events are silently ignored by HomeFormatter."""
|
|
||||||
req_id = "req-4"
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
|
||||||
("token", "No tags here."),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
|
||||||
assert text == "No tags here."
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_mutations_in_stream_end():
|
|
||||||
req_id = "req-5"
|
|
||||||
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
|
||||||
events = [
|
|
||||||
("token", "Done"),
|
|
||||||
("mutations", muts),
|
|
||||||
]
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
end_frame = frames[-1]
|
|
||||||
assert isinstance(end_frame, WsStreamEnd)
|
|
||||||
assert len(end_frame.mutations) == 1
|
|
||||||
assert end_frame.mutations[0]["action"] == "insert"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_home_formatter_frame_order():
|
|
||||||
"""stream_start is first, stream_end is last."""
|
|
||||||
req_id = "req-6"
|
|
||||||
formatter = HomeFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_domain_from_tool_end():
|
|
||||||
req_id = "pop-1"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "task_agent", "result": "ok"}),
|
|
||||||
("token", "Hello"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
|
||||||
assert frames[0].domain == "tasks"
|
|
||||||
assert frames[0].request_id == req_id
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_text_only():
|
|
||||||
req_id = "pop-2"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "timeline_agent", "result": "done"}),
|
|
||||||
("token", "Summary"),
|
("token", "Summary"),
|
||||||
("mutations", []),
|
),
|
||||||
]
|
)
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
assert frames[0].domain == "timelines"
|
assert frames[0].domain.type == "node"
|
||||||
|
assert frames[0].domain.id == "n-1"
|
||||||
|
assert isinstance(frames[1], WsStreamStart)
|
||||||
|
assert isinstance(frames[2], WsStreamText)
|
||||||
|
assert frames[2].chunk == "Summary"
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_formatter_ignores_unknown_events() -> None:
|
||||||
|
formatter = StreamFormatter(request_id="req-3")
|
||||||
|
frames = await _collect(
|
||||||
|
formatter,
|
||||||
|
_stream(("tool_end", {"name": "x"}), ("token", "ok")),
|
||||||
|
)
|
||||||
|
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
assert len(text_frames) == 1
|
assert len(text_frames) == 1
|
||||||
assert text_frames[0].chunk == "Summary"
|
assert text_frames[0].chunk == "ok"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_floating_formatter_no_entity_tags():
|
async def test_stream_formatter_empty_stream_still_brackets() -> None:
|
||||||
"""FloatingFormatter never emits entity tag blocks."""
|
formatter = StreamFormatter(request_id="req-4")
|
||||||
req_id = "pop-3"
|
frames = await _collect(formatter, _stream())
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "note_agent", "result": "data"}),
|
|
||||||
("token", "some text"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
# Only expected frame types
|
|
||||||
for f in frames:
|
|
||||||
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
|
|
||||||
|
|
||||||
|
assert len(frames) == 2
|
||||||
@pytest.mark.asyncio
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
async def test_floating_formatter_end_frame():
|
assert isinstance(frames[1], WsStreamEnd)
|
||||||
req_id = "pop-4"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [
|
|
||||||
("tool_end", {"name": "project_agent", "result": "ok"}),
|
|
||||||
("token", "Done"),
|
|
||||||
("mutations", []),
|
|
||||||
]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_default_domain_on_early_token():
|
|
||||||
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
|
|
||||||
req_id = "pop-5"
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
events = [("token", "hi"), ("mutations", [])]
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
|
||||||
assert frames[0].domain == "tasks"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_floating_formatter_mutations_in_stream_end():
|
|
||||||
req_id = "pop-6"
|
|
||||||
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
|
|
||||||
events = [
|
|
||||||
("token", "Updated"),
|
|
||||||
("mutations", muts),
|
|
||||||
]
|
|
||||||
formatter = FloatingFormatter(request_id=req_id)
|
|
||||||
frames = await collect(formatter, _stream(*events))
|
|
||||||
|
|
||||||
end_frame = frames[-1]
|
|
||||||
assert isinstance(end_frame, WsStreamEnd)
|
|
||||||
assert len(end_frame.mutations) == 1
|
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ class TestPluginRegistry:
|
|||||||
async def test_list_filter_by_query(
|
async def test_list_filter_by_query(
|
||||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
result = await reg.list_plugins(db_session, query="time tracker")
|
result = await reg.list_plugins(db_session, query="time")
|
||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
assert result.plugins[0].id == "plugin-time-tracker"
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
|
WsDomain,
|
||||||
WsFrameType,
|
WsFrameType,
|
||||||
WsHomeRequest,
|
WsHomeRequest,
|
||||||
WsFloatingDomain,
|
WsFloatingDomain,
|
||||||
@@ -178,23 +179,15 @@ def test_stream_text_deserializes():
|
|||||||
def test_stream_end_defaults():
|
def test_stream_end_defaults():
|
||||||
frame = WsStreamEnd(request_id="r1")
|
frame = WsStreamEnd(request_id="r1")
|
||||||
assert frame.type == WsFrameType.stream_end
|
assert frame.type == WsFrameType.stream_end
|
||||||
assert frame.mutations == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_with_mutations():
|
|
||||||
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
|
|
||||||
frame = WsStreamEnd(request_id="r1", mutations=mutations)
|
|
||||||
assert len(frame.mutations) == 1
|
|
||||||
assert frame.mutations[0]["action"] == "create"
|
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_serializes():
|
def test_stream_end_serializes():
|
||||||
data = WsStreamEnd(request_id="r2").model_dump()
|
data = WsStreamEnd(request_id="r2").model_dump()
|
||||||
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
|
assert data == {"type": "stream_end", "request_id": "r2"}
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_deserializes():
|
def test_stream_end_deserializes():
|
||||||
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
|
raw = {"type": "stream_end", "request_id": "r3"}
|
||||||
frame = WsStreamEnd.model_validate(raw)
|
frame = WsStreamEnd.model_validate(raw)
|
||||||
assert frame.request_id == "r3"
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
@@ -203,28 +196,47 @@ def test_stream_end_deserializes():
|
|||||||
|
|
||||||
|
|
||||||
def test_floating_domain_tasks():
|
def test_floating_domain_tasks():
|
||||||
frame = WsFloatingDomain(request_id="r1", domain="tasks")
|
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
||||||
assert frame.type == WsFrameType.floating_domain
|
assert frame.type == WsFrameType.floating_domain
|
||||||
assert frame.domain == "tasks"
|
assert frame.domain.type == "task"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
|
def test_floating_domain_valid_domains():
|
||||||
def test_floating_domain_valid_domains(domain: str):
|
frame = WsFloatingDomain(
|
||||||
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
|
request_id="r1",
|
||||||
assert frame.domain == domain
|
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
||||||
|
)
|
||||||
|
assert frame.domain.type == "project"
|
||||||
|
assert frame.domain.id == "213213-312321-312312-421321"
|
||||||
|
assert frame.domain.section == "task"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_invalid():
|
def test_floating_domain_object_valid():
|
||||||
with pytest.raises(ValidationError):
|
frame = WsFloatingDomain(
|
||||||
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
|
request_id="r1",
|
||||||
|
domain=WsDomain(type="project", id="p1", section="task"),
|
||||||
|
)
|
||||||
|
assert frame.domain.type == "project"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_serializes():
|
def test_floating_domain_serializes():
|
||||||
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
|
d = WsFloatingDomain(
|
||||||
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
|
request_id="r1",
|
||||||
|
domain=WsDomain(type="timeline"),
|
||||||
|
).model_dump()
|
||||||
|
assert d == {
|
||||||
|
"type": "floating_domain",
|
||||||
|
"request_id": "r1",
|
||||||
|
"domain": {"type": "timeline", "id": None, "section": None},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_deserializes():
|
def test_floating_domain_deserializes():
|
||||||
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
|
raw = {
|
||||||
|
"type": "floating_domain",
|
||||||
|
"request_id": "r1",
|
||||||
|
"domain": {"type": "node", "id": "n-1", "section": None},
|
||||||
|
}
|
||||||
frame = WsFloatingDomain.model_validate(raw)
|
frame = WsFloatingDomain.model_validate(raw)
|
||||||
assert frame.domain == "projects"
|
assert frame.domain.type == "node"
|
||||||
|
assert frame.domain.id == "n-1"
|
||||||
|
|||||||
@@ -45,15 +45,13 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
async def _mock_home_stream(user_id, message, context):
|
||||||
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
|
yield "token", "Hello"
|
||||||
yield "mutations", []
|
|
||||||
|
|
||||||
|
|
||||||
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
|
async def _mock_floating_stream(user_id, message, context):
|
||||||
yield "tool_end", {"name": "task_agent", "result": "ok"}
|
yield "floating_domain", {"type": "task", "id": None, "section": None}
|
||||||
yield "token", "Here is a summary"
|
yield "token", "Here is a summary"
|
||||||
yield "mutations", []
|
|
||||||
|
|
||||||
|
|
||||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
@@ -104,7 +102,7 @@ def test_floating_request_produces_domain_frame(client):
|
|||||||
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
||||||
assert domain_frame["domain"] == "tasks"
|
assert domain_frame["domain"]["type"] == "task"
|
||||||
assert domain_frame["request_id"] == "p1"
|
assert domain_frame["request_id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
@@ -113,9 +111,8 @@ def test_home_request_request_id_propagated(client):
|
|||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
req_id = "my-unique-req-id"
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
async def _stream(user_id, message, context, db_session_factory=None):
|
async def _stream(user_id, message, context):
|
||||||
yield "token", "ok"
|
yield "token", "ok"
|
||||||
yield "mutations", []
|
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
|||||||
Reference in New Issue
Block a user