21 Commits

Author SHA1 Message Date
Roberto Musso
f340d0fa3e Fix dev tier: default to power when no subscription exists
The tier is resolved live from the subscriptions table in get_current_user.
Previously fell back to 'free' unconditionally, hitting the 5 runs/day
limit immediately in dev. Now falls back to 'power' (unlimited) when
ENV=dev and no subscription row exists.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 12:32:36 +01:00
Roberto Musso
edc53cb6eb Default to power tier (unlimited) in dev when no subscription exists
Users without a subscription row in dev get power tier so rate limits
and quota checks don't block local development. In prod the fallback
remains free tier as before.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 12:12:43 +01:00
Roberto Musso
725cece5c1 Add run_context to agent tool calls for FE run logging
- AgentTriggerRequest accepts optional agent_id (FE's stable electron-store UUID)
- _make_agent_executor injects run_context into every tool_call frame
  so Electron can attribute actions to the correct agent run
- run_local_agent accepts run_context and sends a run_complete WS frame
  when the run finishes so the FE can close the run record
- trigger_agent_run builds run_context with run_id=run_log.id and the
  stable agent_id, passes it through to run_local_agent

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 09:46:17 +01:00
Roberto Musso
297e20ce8d Fix 422 on agent trigger: accept plural data type names
AgentTriggerRequest.what_to_extract now accepts list[str] instead of
strict Literal values. _to_data_types normalises all FE variants
(tasks/task, notes/note, timelines/timeline/timelineEvents,
projects/project) to the canonical plural form the runner expects,
with deduplication.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-18 00:04:29 +01:00
Roberto Musso
5a03bd1cfb Clean up agent catalog and improve extraction agent prompts
- Remove unused config_schema from AgentCatalogItem (schema + route)
- Fix agent_setup system prompt: add extraction agent base behaviour
  context so journey LLM knows what is already handled and focuses on
  field mappings only; remove redundant data-types question (already
  known from user selection); derive data types list dynamically
- Rewrite processing base prompt to use actual tool names
  (list_tasks, update_task, add_task_comment, list_notes, update_note,
  list_timelines, update_timeline, list_all_projects, create_project)
  and enforce update-first strategy before falling back to creation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-17 23:52:54 +01:00
Roberto Musso
87b7a1c6c9 fix journey setup: honor FE session_id, seed LLM history, and force template on max turns
- Use session_id from the FE frame so replies match the listener key
- Seed conversation with a user message for LLM provider compatibility
- On max turns, nudge the LLM and immediately re-invoke to force
  prompt_template generation instead of deferring to next message
- Fix display_message extraction to safely check for template markers

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 16:25:53 +01:00
Roberto Musso
826f64d6bb refactor local directory agent to two-phase LLM-with-tools architecture
Replace the single-pass FE-driven agent_run/agent_data flow with a
BE-orchestrated two-phase execution using LangChain tool-calling:
- Phase 1 (Triage): explores directory via new filesystem tools, matches
  files to existing projects using PROJECT_TOOLS
- Phase 2 (Processing): reads files and performs CRUD per project group
  with clean LLM context windows

Key changes:
- Add filesystem_agent.py with list_directory, read_file_content,
  get_file_metadata tools using execute_on_client()
- Move setup journey from REST to WebSocket (journey_start/message frames)
- Add batch_runs_per_day billing limit and enforce in /trigger
- Remove deprecated agent_data/agent_complete frame handlers and queues

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-17 08:50:46 +01:00
5faa6b1d7c refactor agents to client-owned config flow 2026-03-16 22:35:46 +01:00
02a9684cd6 scope episodic memory enrichment by session_id 2026-03-16 00:33:11 +01:00
fae9efee0d removed old plan files 2026-03-13 16:58:43 +01:00
30b062dd4a fix floating stream empty responses with sanitizer-safe fallbacks 2026-03-13 16:57:30 +01:00
2a0331d7ce refactor floating_domain to structured object-only payload 2026-03-13 16:09:24 +01:00
13fd8677c1 fix: normalize home task/timeline responses to tag-only lines 2026-03-13 12:16:58 +01:00
9bd629cb59 chore: add interaction tracing and remove personal fields from logs 2026-03-13 10:23:47 +01:00
9c97702daa feat: add letta-style memory tools with request/user debug tracing 2026-03-13 09:34:23 +01:00
a1e364c9c0 refactor: switch to single-agent deep runner and add mock memory/tool tests 2026-03-13 08:20:42 +01:00
5b55f1292a make a single agent 2026-03-13 07:42:36 +01:00
5bc9ea6cd6 fix: make planner schema copilot-compatible and silence usage warning 2026-03-12 23:17:31 +01:00
f7404b6f66 refactor: move memory updates from synthesizer to orchestrator node 2026-03-12 23:03:38 +01:00
d667e43c73 refactor: use native LangGraph streaming and enforce structured summary on workers 2026-03-12 22:50:32 +01:00
fe085a7951 feat: migrate chat orchestration to deep langgraph workers 2026-03-12 22:25:36 +01:00
31 changed files with 2869 additions and 2029 deletions

View File

@@ -0,0 +1,92 @@
"""Deprecate backend agent config tables.
The Electron client is now the source of truth for agent configuration
(directory, extract targets, batch interval, custom prompt). Backend keeps
billing checks and trigger/run logs only.
Revision ID: 9a1f2d0b6c7e
Revises: 818478c251dc
Create Date: 2026-03-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "9a1f2d0b6c7e"
down_revision: Union[str, None] = "818478c251dc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
if "cloud_agent_configs" in existing:
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
if "local_agent_configs" in existing:
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")
def downgrade() -> None:
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
op.execute(
"""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
"""
)
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])

View File

@@ -1,5 +1,5 @@
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
"""Expose tool modules used by deep orchestrator-worker graphs."""
from app.agents import timeline_agent, note_agent, project_agent, task_agent
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -0,0 +1,85 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
These tools delegate to the Electron client via ``execute_on_client()`` using
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
handles actual disk I/O and responds with ``tool_result`` frames.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -1,7 +1,8 @@
"""Note agent — tool definitions for Markdown note CRUD."""
"""Note agent — Markdown note management (list, get, create, update, delete)."""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
@@ -9,14 +10,38 @@ from langchain_core.tools import tool
from app.core.llm import embed
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": project_id or None},
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
@@ -105,4 +130,10 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -1,4 +1,4 @@
"""Project agent — tool definitions for project lifecycle CRUD."""
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
from __future__ import annotations
@@ -8,6 +8,22 @@ from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
@@ -117,4 +133,11 @@ async def delete_project(project_id: str) -> str:
return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -1,14 +1,40 @@
"""Task agent — tool definitions for task and task comment CRUD."""
"""Task agent — full CRUD for tasks and task comments."""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────
@@ -22,11 +48,12 @@ async def list_tasks(
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": project_id or None,
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
@@ -188,8 +215,12 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result["row"]
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
row = result.get("row", {})
row_author = row.get("author", author)
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
@@ -199,4 +230,16 @@ async def delete_task_comment(comment_id: str) -> str:
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -1,21 +1,45 @@
"""Timeline agent — tool definitions for project milestone CRUD."""
"""Timeline agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": project_id or None},
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
@@ -89,4 +113,9 @@ async def delete_timeline(timeline_id: str) -> str:
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -55,12 +55,15 @@ async def get_current_user(
raise credentials_exc
# Live tier lookup — subscription row is the authoritative source.
# In dev, fall back to 'power' (unlimited) so quota limits don't
# block local development when no Stripe subscription exists.
from app.models import Subscription, User # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str = result.scalar_one_or_none() or "free"
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname from user row.
user_result = await db.execute(

View File

@@ -1,54 +1,40 @@
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
Endpoints:
POST /agents/journey/start — start a new journey session
POST /agents/journey/message — continue the conversation
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message``
frames to the functions exported here.
Journey flow:
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
2. Server creates a session, calls the LLM with a contextual system prompt,
and returns the first question.
3. Client sends follow-up messages to ``/message``.
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
5. Server parses the block, sets ``done=True``, and returns the template.
The ``prompt_template`` from the final response is meant to be stored in
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
by the Electron client (via the agent CRUD endpoints).
1. FE sends ``journey_start`` frame with basic agent config (directory,
data_types, schedule).
2. Server creates an in-memory session, sets up a WS executor so the
setup LLM can use file-system tools, does a first directory scrape,
and sends back a ``journey_reply`` with the first question.
3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
6. Server parses the block, sends ``journey_reply`` with ``done=True``
and the template. FE stores it locally.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, status
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.api.deps import get_current_user
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.core.llm import get_llm
from app.db import get_session
from app.models import CloudAgentConfig, LocalAgentConfig
from app.schemas import (
JourneyMessageRequest,
JourneyResponse,
JourneyStartRequest,
UserProfile,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents/journey", tags=["agents"])
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
@@ -59,16 +45,21 @@ _TEMPLATE_END = "PROMPT_TEMPLATE_END"
# Maximum number of conversation turns before the LLM is nudged to wrap up.
_MAX_TURNS: int = 5
# Max tool-calling steps per LLM invocation.
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class _JourneySession:
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
@@ -76,67 +67,77 @@ class _JourneySession:
# session_id → session
_sessions: dict[str, _JourneySession] = {}
_sessions: dict[str, JourneySession] = {}
def _get_session(session_id: str, user_id: str) -> _JourneySession:
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
return None
if s.user_id != user_id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
return None
return s
# ── System prompt builder ─────────────────────────────────────────────────
_LOCAL_PREAMBLE = """\
What kind of files are in the directories you want to monitor? \
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
_CLOUD_PREAMBLE = """\
What kind of emails or messages should I look for? \
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
_SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their {source_description} \
and produce a detailed prompt_template that a separate AI will use as its instruction set.
Your job is to understand exactly what data the user wants to extract from their
local directory and produce a detailed prompt_template that a separate AI will use
as its instruction set.
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content.
2. Which data types to extract: tasks, notes, timelines, and/or projects.
3. How fields should be mapped (e.g. email subject → task title).
4. Priority or status rules (e.g. "urgent" keyword → high priority).
5. Any special handling, date extraction, or exclusions.
The extraction agent already has this base behaviour built in:
- Reads each file using file-system tools.
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
- Sets isAiSuggested=1 and isApproved=0 on every record.
- Only extracts data explicitly present in the files — it never invents information.
The user's custom prompt is appended AFTER this base behaviour, so focus on
what to look for and how to map it — not on the general extraction mechanics.
After 3-5 questions (when you have enough information), output the final prompt_template between \
these exact markers on their own lines:
You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
The user's configured directory is: {directory}
Target data types: {data_types}
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content (confirmed by your exploration).
2. How fields should be mapped (e.g. filename → task title).
3. Priority or status rules (e.g. "urgent" keyword → high priority).
4. Any special handling, date extraction, or exclusions.
After 3-5 questions (when you have enough information), output the final prompt_template
between these exact markers on their own lines:
{template_start}
<the complete extraction prompt here>
{template_end}
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
and must return a JSON array of records in this shape:
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
The prompt_template must be a self-contained instruction for an AI that reads files
and must perform CRUD operations using tools to create records. It should specify:
- What entity types to create (tasks, notes, timelines, projects).
- How to map file content to record fields (camelCase: title, status, priority,
dueDate, projectId, content, etc.).
- That isAiSuggested must be set to 1 and isApproved to 0 on every record.
- Concrete examples of mappings based on what you discovered in the directory.
Rules for the generated template:
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
- Include concrete examples of mappings.
- Mention that Electron adds id/createdAt/updatedAt automatically.
- Set isAiSuggested: true and isApproved: false on every record.
{existing_section}\
Do not ask more than {max_turns} questions total. Start with your first question now.\
Do not ask more than {max_turns} questions total. Begin by exploring the directory,
then ask your first question.\
"""
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
source_description = (
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
)
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_template: str | None = None,
) -> str:
existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n"
@@ -144,7 +145,8 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
else ""
)
return _SYSTEM_PROMPT_TEMPLATE.format(
source_description=source_description,
directory=directory,
data_types=", ".join(data_types),
template_start=_TEMPLATE_START,
template_end=_TEMPLATE_END,
existing_section=existing_section,
@@ -152,10 +154,6 @@ def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
)
def _first_question(agent_type: str) -> str:
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
# ── Template extraction ───────────────────────────────────────────────────
@@ -168,11 +166,37 @@ def _extract_template(text: str) -> str | None:
return text[start_idx:end_idx].strip() or None
# ── LLM call ─────────────────────────────────────────────────────────────
# ── LLM call with tool support ───────────────────────────────────────────
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
"""Build LangChain messages from history and invoke the LLM."""
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
@@ -181,137 +205,194 @@ async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4)
response = await llm.ainvoke(messages)
return response.content # type: ignore[return-value]
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(_MAX_TOOL_STEPS):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_setup: journey tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_setup: journey tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Existing-config loader ────────────────────────────────────────────────
# ── Journey handlers (called from device_ws.py) ──────────────────────────
async def _load_existing_template(
agent_id: str,
async def handle_journey_start(
user_id: str,
db: AsyncSession,
) -> str | None:
"""Return the prompt_template of an existing agent config, or None."""
# Try local first, then cloud.
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == user_id,
)
)
local = local_result.scalar_one_or_none()
if local is not None:
return local.prompt_template
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_start`` WS frame.
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == user_id,
)
)
cloud = cloud_result.scalar_one_or_none()
return cloud.prompt_template if cloud is not None else None
# ── Routes ────────────────────────────────────────────────────────────────
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
async def start_journey(
body: JourneyStartRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> JourneyResponse:
"""Start a new Chatbot Journey session.
If ``agent_id`` is provided the session is pre-seeded with the existing
agent's ``prompt_template`` so the user can refine it.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
# Load existing template (may be None).
existing_template: str | None = None
if body.agent_id:
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
# If agent_id was given but not found, proceed without seeding (don't 404 —
# the user may be starting a fresh journey for a not-yet-persisted config).
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template")
system_prompt = _build_system_prompt(body.agent_type, existing_template)
first_question = _first_question(body.agent_type)
# Use the session_id provided by the FE so the reply matches the
# listener key; fall back to a generated one if absent.
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template)
session_id = str(uuid.uuid4())
session = _JourneySession(
session = JourneySession(
session_id=session_id,
user_id=current_user.id,
agent_type=body.agent_type,
# Seed history with the AI's first question so it stays consistent.
history=[{"role": "assistant", "content": first_question}],
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
)
# Store the system prompt inside the session for reuse in /message.
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
# ws_context executor (already set by the WS handler before calling us).
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
# require at least one user/input message to be present.
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
return JourneyResponse(session_id=session_id, message=first_question, done=False)
logger.info(
"agent_setup: journey session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
async def send_journey_message(
body: JourneyMessageRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> JourneyResponse:
"""Send a message in an existing Chatbot Journey session.
The server appends the user's message to the conversation history,
calls the LLM, and appends the AI reply. When the LLM wraps up with a
``prompt_template`` block the response includes ``done=True`` and the
extracted template.
"""
session = _get_session(body.session_id, current_user.id)
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
# Append user turn to history.
session.history.append({"role": "user", "content": body.message})
# Call the LLM with the full conversation so far.
ai_reply = await _call_llm(system_prompt, session.history)
# Append AI turn.
session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final template.
# Check if the LLM produced the template on the first turn (unlikely but possible).
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
# Strip the sentinel markers from the message shown to the user.
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
if done:
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
# Clean up the session immediately on completion.
_sessions.pop(body.session_id, None)
else:
# Nudge the LLM to wrap up after max turns.
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_message`` WS frame.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
}
# Append user turn.
session.history.append({"role": "user", "content": message})
# Call the LLM with tools.
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final template.
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
# If the LLM didn't produce a template but we've hit max turns, nudge it
# and call the LLM one more time to force template generation.
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
# Add a system-level nudge as a hidden user message.
session.history.append({
"role": "user",
"content": (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
),
})
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
return JourneyResponse(
session_id=body.session_id,
message=display_message,
done=done,
prompt_template=prompt_template,
)
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}

View File

@@ -1,45 +1,36 @@
"""Agent CRUD routes: local directory agents and cloud connector agents.
"""Agent routes.
Endpoints:
GET /agents/catalog — hardcoded agent type catalog
GET /agents/local — list user's local agent configs
POST /agents/local create local agent (tier-gated)
PUT /agents/local/{agent_id} — partial update (ownership check)
DELETE /agents/local/{agent_id} — delete + cascade run logs
GET /agents/cloud — list user's cloud agent configs
POST /agents/cloud — create cloud agent (tier-gated)
PUT /agents/cloud/{agent_id} — partial update (ownership check)
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
GET /agents/runs — paginated run logs (agent_id, page, limit)
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/triggertrigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Any
import uuid
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import func, or_, select
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import run_cloud_agent, run_local_agent
from app.core.agent_runner import run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
CloudAgentConfigCreate,
CloudAgentConfigResponse,
CloudAgentConfigUpdate,
LocalAgentConfigCreate,
LocalAgentConfigResponse,
LocalAgentConfigUpdate,
AgentTriggerRequest,
UserProfile,
)
@@ -56,39 +47,21 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
# ── Model → schema converters ─────────────────────────────────────────
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
return LocalAgentConfigResponse(
id=a.id,
name=a.name,
device_id=a.device_id,
directory_paths=a.directory_paths,
data_types=a.data_types,
prompt_template=a.prompt_template,
file_extensions=a.file_extensions,
schedule_cron=a.schedule_cron,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
return CloudAgentConfigResponse(
id=a.id,
provider=a.provider, # type: ignore[arg-type]
name=a.name,
data_types=a.data_types,
prompt_template=a.prompt_template,
schedule_cron=a.schedule_cron,
filter_config=a.filter_config,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
@@ -105,77 +78,42 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
)
# ── Ownership-checked lookups ─────────────────────────────────────────
async def _get_local_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> LocalAgentConfig:
result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
async def _get_cloud_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> CloudAgentConfig:
result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
# ── Tier limit helper ─────────────────────────────────────────────────
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
"""Return combined enabled local + cloud agent count for the user."""
local_count = (
await db.execute(
select(func.count(LocalAgentConfig.id)).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
cloud_count = (
await db.execute(
select(func.count(CloudAgentConfig.id)).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
return local_count + cloud_count
def _enforce_agent_limit(tier: str, current_count: int) -> None:
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
# ── Local page schema (used by runs endpoint) ─────────────────────────
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
class _RunsPage(BaseModel):
total: int
page: int
limit: int
items: list[AgentRunLogResponse]
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@@ -209,229 +147,55 @@ async def get_agent_catalog(
]
# ── Local agent CRUD ──────────────────────────────────────────────────
@router.get("/local", response_model=list[LocalAgentConfigResponse])
async def list_local_agents(
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[LocalAgentConfigResponse]:
"""List all local directory agent configs owned by the authenticated user."""
result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
)
return [_to_local_response(a) for a in result.scalars().all()]
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_local_agent(
body: LocalAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Create a new local directory agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = LocalAgentConfig(
user_id=current_user.id,
name=body.name,
device_id=body.device_id,
directory_paths=body.directory_paths,
data_types=body.data_types,
prompt_template=body.prompt_template,
file_extensions=body.file_extensions,
schedule_cron=body.schedule_cron,
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
async def update_local_agent(
agent_id: str,
body: LocalAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Partially update a local agent config. Only provided fields are changed."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.delete("/local/{agent_id}", response_model=dict)
async def delete_local_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a local agent config. Associated run logs are cascade-deleted."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Cloud agent CRUD ──────────────────────────────────────────────────
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
async def list_cloud_agents(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[CloudAgentConfigResponse]:
"""List all cloud connector agent configs owned by the authenticated user."""
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
)
return [_to_cloud_response(a) for a in result.scalars().all()]
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_cloud_agent(
body: CloudAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Create a new cloud connector agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = CloudAgentConfig(
user_id=current_user.id,
provider=body.provider,
name=body.name,
data_types=body.data_types,
prompt_template=body.prompt_template,
oauth_token_encrypted=body.oauth_token_encrypted,
schedule_cron=body.schedule_cron,
filter_config=body.filter_config,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
async def update_cloud_agent(
agent_id: str,
body: CloudAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Partially update a cloud agent config. Only provided fields are changed."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.delete("/cloud/{agent_id}", response_model=dict)
async def delete_cloud_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Run logs ──────────────────────────────────────────────────────────
@router.get("/runs", response_model=_RunsPage)
async def list_run_logs(
agent_id: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=20, ge=1, le=100),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _RunsPage:
"""Return paginated run logs for the authenticated user.
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
"""
base_filter = [AgentRunLog.user_id == current_user.id]
if agent_id:
base_filter.append(AgentRunLog.agent_id == agent_id)
total = (
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
).scalar_one()
result = await db.execute(
select(AgentRunLog)
.where(*base_filter)
.order_by(AgentRunLog.started_at.desc())
.offset((page - 1) * limit)
.limit(limit)
)
items = [_to_run_log_response(log) for log in result.scalars().all()]
return _RunsPage(total=total, page=page, limit=limit, items=items)
# ── Manual trigger stub ───────────────────────────────────────────────
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
agent_id: str,
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Manually trigger an agent run.
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
Looks up the agent config (local or cloud) by ID with ownership check,
creates a run log entry with ``status="running"``, and returns it.
Actual dispatch to the agent runner is wired in Step 3.4 once
``DeviceConnectionManager`` and ``agent_runner`` are available.
"""
# Determine agent type by trying local first, then cloud.
# Keep the full config object so we can pass it to the agent runner.
local_config: LocalAgentConfig | None = None
cloud_config: CloudAgentConfig | None = None
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == current_user.id,
)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
)
local_config = local_result.scalar_one_or_none()
if local_config is not None:
agent_type = "local"
else:
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == current_user.id,
)
)
cloud_config = cloud_result.scalar_one_or_none()
if cloud_config is not None:
agent_type = "cloud"
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
run_log = AgentRunLog(
agent_id=agent_id,
agent_type=agent_type,
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
@@ -439,14 +203,14 @@ async def trigger_agent_run(
await db.commit()
await db.refresh(run_log)
# Dispatch the run as a background task — returns 202 immediately.
if agent_type == "local" and local_config is not None:
asyncio.create_task(
run_local_agent(current_user.id, local_config, run_log, device_manager)
)
elif agent_type == "cloud" and cloud_config is not None:
asyncio.create_task(
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)

View File

@@ -10,9 +10,7 @@ from fastapi.responses import JSONResponse
from app.api.deps import get_current_user
from app.core.deep_agent import run_home
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
from app.schemas import ChatRequest, ChatResponse, UserProfile
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
@@ -22,21 +20,10 @@ async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""Route a chat message through the Home deep agent (non-streaming)."""
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(current_user.id, body.message)
context = {
**body.context.model_dump(),
**memory_context,
}
response_text = await run_home(
"""REST fallback for home chat when websocket streaming is unavailable."""
response = await run_home(
user_id=current_user.id,
message=body.message,
context=context,
db_session_factory=async_session,
context=body.context.model_dump(),
)
result = ChatResponse(response=response_text)
return JSONResponse(content=result.model_dump())
return JSONResponse(content={"response": response})

View File

@@ -14,11 +14,11 @@ Protocol:
4. Session enters message dispatch loop + heartbeat.
Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``agent_data`` enqueued in the per-run agent data queue.
- ``agent_complete`` → sends None sentinel to close the queue stream.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
- ``tool_result`` → resolves a pending tool-call Future.
- ``journey_start`` → starts a guided setup journey session.
- ``journey_message`` continues a journey conversation.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
@@ -39,12 +39,13 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.deep_agent import run_home_stream, run_floating_stream
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
@@ -147,37 +148,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"device_ws: tool_result missing id from user=%s", user_id
)
elif frame_type == WsFrameType.agent_data:
run_id = frame.get("run_id")
if run_id:
try:
queue = device_manager.get_agent_data_queue(user_id, run_id)
await queue.put(frame)
except RuntimeError:
logger.warning(
"device_ws: agent_data for unknown run user=%s run=%s",
user_id,
run_id,
)
else:
logger.warning(
"device_ws: agent_data missing run_id from user=%s", user_id
)
elif frame_type == WsFrameType.agent_complete:
run_id = frame.get("run_id")
if run_id:
try:
queue = device_manager.get_agent_data_queue(user_id, run_id)
# Sentinel: signals the agent data stream is finished.
await queue.put(None)
except RuntimeError:
pass
else:
logger.warning(
"device_ws: agent_complete missing run_id from user=%s", user_id
)
elif frame_type == WsFrameType.home_request:
asyncio.create_task(
_handle_home_request(websocket, user_id, frame)
@@ -188,6 +158,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_message:
asyncio.create_task(
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
@@ -200,35 +180,13 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
# ── v3 Chat Handlers ──────────────────────────────────────────────────
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
call_id = payload["id"]
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, call_id)
try:
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
except asyncio.TimeoutError:
logger.error(
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
call_id, payload.get("action"), user_id,
)
# Clean up the pending future so it doesn't leak
conn = device_manager._connections.get(user_id)
if conn:
conn.pending_calls.pop(call_id, None)
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
call_id, type(result).__name__,
list(result.keys()) if isinstance(result, dict) else "N/A")
if result is None:
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
return result
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
return _executor
@@ -241,14 +199,27 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(user_id, message)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
@@ -256,12 +227,11 @@ async def _handle_home_request(
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(
user_id, message, context, db_session_factory=async_session
)
formatter = HomeFormatter(request_id=request_id)
event_stream = run_home_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
@@ -276,8 +246,15 @@ async def _handle_home_request(
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks)
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
async def _handle_floating_request(
@@ -290,23 +267,37 @@ async def _handle_floating_request(
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(user_id, message)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {"scope": scope, **memory_context}
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(
user_id, message, context, scope=scope,
db_session_factory=async_session,
)
formatter = FloatingFormatter(request_id=request_id)
event_stream = run_floating_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
@@ -323,8 +314,72 @@ async def _handle_floating_request(
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks)
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
async def _handle_journey_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_start frame — explores directory and sends first question."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_start(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
logger.error(
"device_ws: journey_start failed user=%s: %s", user_id, exc
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": frame.get("session_id", ""),
"message": f"Failed to start journey: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
async def _handle_journey_message(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_message frame — continues the journey conversation."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_message(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
session_id = frame.get("session_id", "")
logger.error(
"device_ws: journey_message failed user=%s session=%s: %s",
user_id, session_id, exc,
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey error: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
# ── Heartbeat ─────────────────────────────────────────────────────────
@@ -360,6 +415,3 @@ async def _mark_runs_disconnected(user_id: str) -> None:
user_id,
exc,
)

View File

@@ -21,6 +21,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"free": {
"agents": 3,
"batch_active": 2,
"batch_runs_per_day": 5,
"cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1,
@@ -31,6 +32,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"batch_runs_per_day": 50,
"cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1,
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1,
@@ -51,6 +54,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"team": {
"agents": -1,
"batch_active": -1,
"batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1,
@@ -77,16 +81,18 @@ class TierManager:
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'free'`` when no subscription row exists.
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
from app.config.settings import settings # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "free"
return "power" if settings.ENV == "dev" else "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────

View File

@@ -0,0 +1,30 @@
"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
"""Common base for non-chat agents still using the old base contract."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
return []

View File

@@ -1,15 +1,15 @@
"""Agent run manager.
"""Agent run orchestrator.
Drives two agent types:
* **Local directory agent** — sends an ``agent_run`` frame to the connected
Electron device, waits for the device to stream back file contents via
``agent_data`` frames, then calls the LLM to extract structured items from
each file and pushes inserts to Electron via tool-call round-trips.
* **Local directory agent** — two-phase execution that mirrors the
``deep_agent.py`` tool-calling pattern. Phase 1 (Triage) explores the
user's directory via file-system tools and groups files by project.
Phase 2 (Processing) reads full file contents and performs CRUD
operations using the standard entity tools (tasks, notes, etc.).
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
Teams, Outlook) and pushes extracted items to Electron. **This path is
a stub** — provider integrations are implemented in Step 3.6.
Teams, Outlook) and pushes extracted items to Electron.
Usage
-----
@@ -33,11 +33,17 @@ from datetime import datetime, timedelta, timezone
from typing import Any
from croniter import croniter
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from sqlalchemy import select
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.device_manager import DeviceConnectionManager
from app.core.llm import get_llm
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
@@ -45,50 +51,108 @@ logger = logging.getLogger(__name__)
# ── Timeouts ───────────────────────────────────────────────────────────────
# Max seconds to wait for Electron to finish streaming file data.
_FILE_READ_TIMEOUT: int = 120
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
_INSERT_TIMEOUT: int = 30
# Max seconds to wait for a single tool-call round-trip (FE → BE).
_TOOL_CALL_TIMEOUT: int = 30
# Max LLM reasoning steps per phase.
_MAX_TRIAGE_STEPS: int = 10
_MAX_PROCESSING_STEPS: int = 12
# ── Allowed tables & extraction schema hints ───────────────────────────────
# ── Data-type to tool mapping ─────────────────────────────────────────────
_ALLOWED_TABLES: frozenset[str] = frozenset(
{"tasks", "notes", "timelines", "projects", "taskComments"}
)
# Field descriptions fed to the extraction LLM as concise schema references.
_TABLE_SCHEMAS: dict[str, str] = {
"tasks": (
"title (str, required), description (str), "
"status (todo|in_progress|done, default todo), "
"priority (high|medium|low, default medium), "
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
),
"notes": "title (str, required), content (str, markdown), projectId (str)",
"timelines": (
"title (str, required), projectId (str, required), date (ms timestamp int)"
),
"projects": "name (str, required), clientId (str)",
"taskComments": "taskId (str, required), author (str), content (str, required)",
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"projects": PROJECT_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
}
_EXTRACTION_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Given a document, extract structured records matching the user's instructions.
# ── Triage prompt ─────────────────────────────────────────────────────────
Output a JSON array (no markdown fences, no explanation) of objects shaped:
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
_TRIAGE_SYSTEM_PROMPT = """\
You are a file triage assistant for a freelance project management tool.
Your job is to explore a local directory on the user's device, understand its
structure, and group files by project context.
Allowed table names and their fields:
{table_schemas}
You have access to these tools:
- list_directory: to map folder structure
- get_file_metadata: to check creation/modification dates
- read_file_content: to read brief snippets when needed for categorisation
- list_projects / list_all_projects / get_project: to fetch existing projects
from the user's workspace and match files to them
Rules:
- Only extract tables listed in the "data_types" instructions.
- Use camelCase field names exactly as shown above.
- Omit optional fields you cannot determine; do not invent data.
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
- If nothing relevant is found, return an empty JSON array: []
- Return ONLY the JSON array.
Instructions:
1. Start by calling list_directory on the configured root path.
2. Explore subdirectories as needed to understand the structure.
3. Use get_file_metadata to check modification dates. Skip files that have
NOT been modified since: {last_run_at}.
4. Call list_all_projects to get the user's existing projects.
5. Match files to existing projects by name, folder structure, or content hints.
6. If files don't match any existing project, group them under "standalone".
{custom_prompt_section}
Target entity types to extract: {data_types}
File extensions to consider: {file_extensions}
When you have finished exploring, output ONLY a JSON object (no markdown
fences, no explanation) mapping project IDs or "standalone" to file path
arrays:
{{"<project_id>": ["<file_path>", ...], "standalone": ["<file_path>", ...]}}
Return ONLY the JSON object as your final message.
"""
# ── Processing prompt ─────────────────────────────────────────────────────
_PROCESSING_BASE_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project
Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.
Update-first rules (apply in this order):
Tasks:
- Call list_tasks to find a match by title or context.
- If found: call add_task_comment (author "Adiuva"), update_task to set
assignees, state (ToDo / In Progress / Completed), or other fields.
- If NOT found: call create_task with isAiSuggested=1, isApproved=0.
Timelines:
- Call list_timelines to find a match by title or date.
- If found: call update_timeline to edit fields or mark it complete.
- If NOT found: call create_timeline with isAiSuggested=1, isApproved=0.
Notes:
- Call list_notes to find a match by title or topic, then get_note to
read its current content.
- If found: call update_note with the merged content.
- If NOT found: call create_note with isAiSuggested=1, isApproved=0.
Projects:
- Call list_all_projects to check for a match first.
- Only call create_project if the information is clearly significant and
no existing project matches. Set isAiSuggested=1, isApproved=0.
{project_context}
Files to process:
{file_list}
{custom_prompt_section}
After processing all files, respond with a brief summary of what you updated
and what you created.
"""
@@ -118,100 +182,151 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
return False # Fail-safe: don't trigger if expression is invalid.
# ── LLM extraction ─────────────────────────────────────────────────────────
# ── WS executor for agent context ─────────────────────────────────────────
async def _extract_items_from_content(
prompt_template: str,
file_content: str,
data_types: list[str],
) -> list[dict[str, Any]]:
"""Call the LLM to extract structured records from *file_content*.
Returns a validated list of ``{table: str, data: dict}`` objects.
Items referencing tables not in *data_types* are discarded.
"""
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
if not allowed:
return []
schema_text = "\n".join(
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
)
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
user_prompt = (
f"User instructions: {prompt_template}\n\n"
f"Extract these record types: {', '.join(allowed)}\n\n"
f"Document:\n{file_content[:8000]}"
)
llm = get_llm()
raw = ""
try:
response = await llm.ainvoke(
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
)
raw = str(response.content).strip()
items: list[dict] = json.loads(raw)
if not isinstance(items, list):
raise ValueError("LLM response is not a JSON array")
except json.JSONDecodeError as exc:
logger.warning(
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
exc,
raw,
)
return []
# Other exceptions (LLM API errors, network errors) propagate to the
# caller (run_local_agent) which records them per-file in the run log.
validated: list[dict[str, Any]] = []
for item in items:
table = item.get("table")
data = item.get("data")
if not isinstance(table, str) or table not in allowed:
continue
if not isinstance(data, dict) or not data:
continue
# Strip any server-generated or forbidden fields.
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
data.pop(_field, None)
validated.append({"table": table, "data": data})
return validated
# ── Tool-call insert helper ─────────────────────────────────────────────────
async def _send_insert_to_client(
def _make_agent_executor(
user_id: str,
table: str,
data: dict[str, Any],
device_mgr: DeviceConnectionManager,
) -> dict[str, Any]:
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
run_context: dict | None = None,
) -> Any:
"""Create a WS callback for ``set_client_executor()`` so that all tools
can use ``execute_on_client()`` during an agent run.
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
review AI-produced records before they are treated as confirmed.
Raises ``asyncio.TimeoutError`` if Electron does not respond within
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
disconnects before the frame can be sent.
If *run_context* is provided it is attached to every ``tool_call`` frame
so the Electron client can attribute actions to the correct agent run.
"""
call_id = str(uuid.uuid4())
payload: dict[str, Any] = {
"type": "tool_call",
"id": call_id,
"action": "insert",
"table": table,
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
}
fut = device_mgr.create_pending_call(user_id, call_id)
await device_mgr.send_frame(user_id, payload)
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
async def _executor(payload: dict) -> dict:
payload["type"] = "tool_call"
if run_context:
payload["run_context"] = run_context
call_id = payload["id"]
fut = device_mgr.create_pending_call(user_id, call_id)
await device_mgr.send_frame(user_id, payload)
return await asyncio.wait_for(fut, timeout=_TOOL_CALL_TIMEOUT)
return _executor
# ── Local agent runner ──────────────────────────────────────────────────────
# ── LLM tool-calling loop (mirrors deep_agent._run_single_agent) ──────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _run_agent_with_tools(
*,
system_prompt: str,
user_message: str,
tools: list[Any],
max_steps: int,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response.
Follows the same pattern as ``deep_agent._run_single_agent``:
bind tools → invoke → handle tool calls → repeat until final text.
"""
llm = get_llm()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message),
]
tool_calls_count = 0
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps, get final response without tools.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Triage map parser ─────────────────────────────────────────────────────
def _parse_triage_map(raw: str) -> dict[str, list[str]] | None:
"""Extract the JSON triage map from the LLM's final response."""
text = raw.strip()
# Try direct parse first.
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return {k: v for k, v in parsed.items() if isinstance(v, list)}
except json.JSONDecodeError:
pass
# Try extracting JSON from markdown fences or surrounding text.
import re
match = re.search(r"\{[\s\S]*\}", text)
if match:
try:
parsed = json.loads(match.group(0))
if isinstance(parsed, dict):
return {k: v for k, v in parsed.items() if isinstance(v, list)}
except json.JSONDecodeError:
pass
return None
# ── Tool list builder ─────────────────────────────────────────────────────
def _build_processing_tools(data_types: list[str]) -> list[Any]:
"""Build the tool list for Phase 2 based on user's data_types selection."""
tools: list[Any] = list(FILESYSTEM_TOOLS)
for dt in data_types:
dt_tools = _DATA_TYPE_TOOLS.get(dt)
if dt_tools:
tools.extend(dt_tools)
return tools
# ── Local agent runner (two-phase) ─────────────────────────────────────────
async def run_local_agent(
@@ -219,144 +334,163 @@ async def run_local_agent(
config: LocalAgentConfig,
run_log: AgentRunLog,
device_mgr: DeviceConnectionManager,
run_context: dict | None = None,
) -> None:
"""Execute a local directory agent run end-to-end.
"""Execute a local directory agent run using two-phase LLM-with-tools.
Steps:
Phase 1 — Triage:
Explore the directory structure, check metadata, match files to
existing projects. Output: a JSON map of project → file paths.
1. Verify the device identified by ``config.device_id`` is currently online.
2. Pre-create the agent_data queue so no incoming frames are lost.
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
4. Consume ``agent_data`` frames until the ``None`` sentinel from
``agent_complete``.
5. For each received file call the LLM to extract ``{table, data}`` items.
6. Push each item to Electron as an ``insert`` tool-call; include
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
7. Persist the run outcome (status, counts, errors) and update
``config.last_run_at``.
Phase 2 — Processing:
For each project group, read full file contents and perform CRUD
operations using the standard entity tools.
"""
run_id = run_log.id
# ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id, config.device_id):
# ── Device online check ─────────────────────────────────────────
target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
if target_device_id:
is_online = device_mgr.is_online(user_id, target_device_id)
else:
is_online = device_mgr.is_online(user_id)
if not is_online:
logger.info(
"agent_runner: skip run=%s — device %r offline for user=%s",
run_id,
config.device_id,
target_device_id or "<any>",
user_id,
)
await _finalize_run(
run_log,
status="error",
errors=[f"Device {config.device_id!r} is not connected"],
errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
)
return
# ── 2. Pre-create agent_data queue ────────────────────────────────
try:
device_mgr.get_agent_data_queue(user_id, run_id)
except RuntimeError:
await _finalize_run(
run_log,
status="error",
errors=["Device disconnected before agent run could start"],
)
return
# ── Set up WS executor for tools ────────────────────────────────
executor = _make_agent_executor(user_id, device_mgr, run_context)
set_client_executor(executor)
# ── 3. Send agent_run frame ────────────────────────────────────────
frame: dict[str, Any] = {
"type": "agent_run",
"run_id": run_id,
"agent_id": config.id,
"config": {
"paths": config.directory_paths,
"file_extensions": config.file_extensions,
"prompt_template": config.prompt_template,
"data_types": config.data_types,
},
}
try:
await device_mgr.send_frame(user_id, frame)
except RuntimeError as exc:
device_mgr.cleanup_agent_data_queue(user_id, run_id)
await _finalize_run(
run_log,
status="error",
errors=[f"Failed to send agent_run frame: {exc}"],
)
return
logger.info(
"agent_runner: sent agent_run run=%s agent=%s user=%s",
run_id,
config.id,
user_id,
)
# ── 4. Consume agent_data frames ──────────────────────────────────
files: list[dict[str, Any]] = []
errors: list[str] = []
try:
queue = device_mgr.get_agent_data_queue(user_id, run_id)
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
while True:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
errors.append("Timed out waiting for file data from device")
break
try:
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
except asyncio.TimeoutError:
errors.append("Timed out waiting for file data from device")
break
if frame_data is None:
# Sentinel from agent_complete — stream is done.
break
files.extend(frame_data.get("files", []))
except RuntimeError as exc:
errors.append(f"Queue error reading agent data: {exc}")
# ── 56. Extract + insert ─────────────────────────────────────────
items_processed = 0
items_created = 0
for file_info in files:
file_path: str = file_info.get("path", "<unknown>")
content: str = file_info.get("content", "")
if not content:
continue
items_processed += 1
try:
extracted = await _extract_items_from_content(
config.prompt_template, content, config.data_types
try:
# ── Phase 1: Triage ─────────────────────────────────────────
logger.info("agent_runner: run=%s phase=triage start user=%s", run_id, user_id)
last_run_str = "never (process all files)"
if config.last_run_at:
last_run_str = config.last_run_at.isoformat()
custom_section = ""
if config.prompt_template:
custom_section = f"User instructions:\n{config.prompt_template}"
file_ext_str = ", ".join(config.file_extensions) if config.file_extensions else "all"
triage_prompt = _TRIAGE_SYSTEM_PROMPT.format(
last_run_at=last_run_str,
custom_prompt_section=custom_section,
data_types=", ".join(config.data_types),
file_extensions=file_ext_str,
)
directory_paths = config.directory_paths
triage_user_msg = (
f"Explore these directories and produce the triage map:\n"
f"{json.dumps(directory_paths, ensure_ascii=False)}"
)
triage_tools: list[Any] = list(FILESYSTEM_TOOLS) + list(PROJECT_TOOLS)
triage_response = await _run_agent_with_tools(
system_prompt=triage_prompt,
user_message=triage_user_msg,
tools=triage_tools,
max_steps=_MAX_TRIAGE_STEPS,
)
triage_map = _parse_triage_map(triage_response)
if not triage_map:
errors.append(f"Triage phase failed to produce a valid file map: {triage_response[:500]}")
await _finalize_run(run_log, status="error", errors=errors)
return
logger.info(
"agent_runner: run=%s triage complete groups=%d total_files=%d",
run_id,
len(triage_map),
sum(len(files) for files in triage_map.values()),
)
# ── Phase 2: Processing (per group) ─────────────────────────
processing_tools = _build_processing_tools(config.data_types)
for group_key, file_paths in triage_map.items():
if not file_paths:
continue
logger.info(
"agent_runner: run=%s phase=processing group=%s files=%d",
run_id,
group_key,
len(file_paths),
)
except Exception as exc:
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
continue
for item in extracted:
# Build project context for the LLM.
if group_key == "standalone":
project_context = "These files are not associated with any existing project."
else:
project_context = f"These files belong to project ID: {group_key}. Use this project_id when creating records."
file_list_str = "\n".join(f"- {fp}" for fp in file_paths)
processing_prompt = _PROCESSING_BASE_PROMPT.format(
data_types=", ".join(config.data_types),
project_context=project_context,
file_list=file_list_str,
custom_prompt_section=custom_section,
)
items_processed += len(file_paths)
try:
result = await _send_insert_to_client(
user_id, item["table"], item["data"], device_mgr
result_text = await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message="Process the listed files now.",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
if result.get("error"):
errors.append(
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
)
else:
items_created += 1
except asyncio.TimeoutError:
errors.append(
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
logger.info(
"agent_runner: run=%s group=%s processing_result=%s",
run_id,
group_key,
result_text[:500],
)
# Count created items by scanning tool call results.
# The tools themselves handle creation; we estimate from the
# summary. A more precise count would require intercepting
# tool results, but the summary is sufficient for the run log.
except Exception as exc:
errors.append(f"Processing error for group '{group_key}': {exc}")
logger.error(
"agent_runner: run=%s group=%s processing failed: %s",
run_id,
group_key,
exc,
)
except RuntimeError as exc:
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
# ── 7. Finalise ────────────────────────────────────────────────────
device_mgr.cleanup_agent_data_queue(user_id, run_id)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_id, exc)
finally:
clear_client_executor()
if errors and items_created == 0:
# ── Finalise ────────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
@@ -369,19 +503,30 @@ async def run_local_agent(
items_processed=items_processed,
items_created=items_created,
errors=errors,
update_config_last_run=True,
update_config_last_run=False,
config_id=config.id,
config_type="local",
)
logger.info(
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
"agent_runner: run=%s done status=%s processed=%d errors=%d",
run_id,
final_status,
items_processed,
items_created,
len(errors),
)
# Notify the Electron client that the run is complete so it can close
# the run record in its local SQLite.
if run_context and device_mgr.is_online(user_id):
try:
await device_mgr.send_frame(user_id, {
"type": "run_complete",
"run_context": run_context,
"status": final_status,
})
except Exception as exc:
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_id, exc)
# ── Cloud agent runner ─────────────────────────────────────────────────────
@@ -405,8 +550,7 @@ async def run_cloud_agent(
3. Instantiate the provider client (Gmail or MS Graph).
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
the first run) applying ``config.filter_config`` filters.
5. For each message/email call ``_extract_items_from_content`` with
``config.prompt_template`` to get structured ``{table, data}`` items.
5. For each message/email call the LLM to extract structured items.
6. Push each item to Electron as an ``insert`` tool-call.
7. If the provider refreshed its access token, re-encrypt and write it
back to ``config.oauth_token_encrypted``.
@@ -514,37 +658,40 @@ async def run_cloud_agent(
user_id,
)
# ── 56. Extract + insert ─────────────────────────────────────────
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
try:
extracted = await _extract_items_from_content(
config.prompt_template, content_text, config.data_types
)
except Exception as exc:
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
continue
# ── 56. Extract + insert via LLM with tools ─────────────────────
executor = _make_agent_executor(user_id, device_mgr)
set_client_executor(executor)
try:
processing_tools = _build_processing_tools(config.data_types)
custom_section = ""
if config.prompt_template:
custom_section = f"User instructions:\n{config.prompt_template}"
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
processing_prompt = _PROCESSING_BASE_PROMPT.format(
data_types=", ".join(config.data_types),
project_context="Determine the appropriate project from the message context.",
file_list=f"Message from {config.provider} (id: {msg.id})",
custom_prompt_section=custom_section,
)
for item in extracted:
try:
result = await _send_insert_to_client(
user_id, item["table"], item["data"], device_mgr
await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
if result.get("error"):
errors.append(
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
)
else:
items_created += 1
except asyncio.TimeoutError:
errors.append(
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
)
except RuntimeError as exc:
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
finally:
clear_client_executor()
# ── 7. Persist refreshed token (if any) ───────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
@@ -610,60 +757,11 @@ async def trigger_pending_runs(
* Runs execute **sequentially** to avoid flooding the WS connection.
"""
logger.info(
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
user_id,
device_id,
)
async with async_session() as db:
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
LocalAgentConfig.device_id == device_id,
)
)
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
# Build ordered list of overdue (type, config) pairs.
pending: list[tuple[str, Any]] = []
for cfg in local_configs:
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
pending.append(("local", cfg))
for cfg in cloud_configs:
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
pending.append(("cloud", cfg))
if not pending:
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
return
logger.info(
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
)
for agent_type, cfg in pending:
# Create a fresh run log for this scheduled dispatch.
run_log = AgentRunLog(
agent_id=cfg.id,
agent_type=agent_type,
user_id=user_id,
status="running",
)
async with async_session() as db:
db.add(run_log)
await db.commit()
await db.refresh(run_log)
if agent_type == "local":
await run_local_agent(user_id, cfg, run_log, device_mgr)
else:
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
return
# ── Internal helper ─────────────────────────────────────────────────────────

File diff suppressed because it is too large Load Diff

View File

@@ -3,20 +3,15 @@
Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).
The manager participates in two interaction patterns:
The manager handles the **tool-call round-trip** pattern:
- Backend sends ``tool_call`` frame → Electron executes the action →
returns ``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
1. **Tool-call round-trip** (bidirectional CRUD):
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
2. **Agent-data streaming** (local directory agent runs):
- Backend sends ``agent_run`` frame → Electron reads files and sends
back a stream of ``agent_data`` frames followed by ``agent_complete``.
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
a specific ``run_id`` so the agent runner can iterate frames.
This pattern is used by all tools (CRUD, file-system, etc.) via
``execute_on_client()`` in ``ws_context.py``.
The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
@@ -42,8 +37,6 @@ class DeviceConnection:
device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
# Per-run queues for agent_data / agent_complete frames.
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
class DeviceConnectionManager:
@@ -153,31 +146,6 @@ class DeviceConnectionManager:
if fut is not None and not fut.done():
fut.set_result(result)
# ── Agent-data queue ──────────────────────────────────────────────
def get_agent_data_queue(
self, user_id: str, run_id: str
) -> asyncio.Queue[dict | None]:
"""Return (creating if absent) the queue for *run_id* agent frames.
The agent runner reads from this queue. The device WS handler writes
to it. ``None`` is the sentinel that signals the stream is finished.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"get_agent_data_queue: user {user_id!r} is not connected"
)
if run_id not in conn.agent_data_queues:
conn.agent_data_queues[run_id] = asyncio.Queue()
return conn.agent_data_queues[run_id]
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
"""Remove the queue for *run_id* once a run has completed."""
conn = self._connections.get(user_id)
if conn:
conn.agent_data_queues.pop(run_id, None)
# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()

View File

@@ -1,6 +1,6 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
@@ -18,6 +18,7 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
@@ -32,6 +33,14 @@ from app.config.settings import settings
# Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""

View File

@@ -43,15 +43,21 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
"""Enrich agent context with memory and persist interactions after."""
"""Enrich orchestrator context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None:
self._db = db
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
"""Build memory context dict to inject into the agent before LLM call.
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
core_memory — {key: plaintext_value, ...}
@@ -65,9 +71,21 @@ class MemoryMiddleware:
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
len(core),
len(associative),
len(episodic),
len(proactive),
)
return {
"core_memory": core,
"associative_memory": associative,
@@ -81,6 +99,7 @@ class MemoryMiddleware:
session_id: str,
message: str,
response: str,
trace_id: str | None = None,
) -> None:
"""Summarise and store a completed interaction in episodic memory.
@@ -103,11 +122,19 @@ class MemoryMiddleware:
self._db.add(row)
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
session_id,
)
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str) -> None:
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id)
if fernet is None:
@@ -133,10 +160,176 @@ class MemoryMiddleware:
))
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: update_core trace=%s user=%s tier=%s key=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
key,
)
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
"""Return core memory as editable blocks (label/value)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore)
.where(MemoryCore.user_id == user_id)
.order_by(MemoryCore.key.asc())
)
rows = result.scalars().all()
out: list[dict[str, str]] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
"""Return a single core memory block value by label."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
return None
value = _safe_decrypt(fernet, row.value_encrypted)
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
return value
async def delete_core(self, user_id: str, label: str) -> bool:
"""Delete a core memory block by label. Returns True if deleted."""
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
return False
await self._db.delete(row)
try:
await self._db.commit()
logger.info("memory: delete_core user=%s label=%s", user_id, label)
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
"""Append content to a core block, creating it if missing."""
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
return
await self.update_core(user_id, label, f"{current}\n{content}")
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
"""Replace one exact string inside a core block. Returns False if not found."""
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=None,
entity_type=source,
entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search recall memory (episodic summaries) by keyword."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
# ── Private helpers ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
@@ -148,6 +341,16 @@ class MemoryMiddleware:
return None
return Fernet(user.encryption_key.encode())
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
@@ -183,10 +386,17 @@ class MemoryMiddleware:
out.append(plaintext)
return out
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
async def _load_episodic(
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
query
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
)

View File

@@ -1,141 +1,47 @@
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
* ``("token", str)`` — supervisor text token
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
HomeFormatter:
* Streams text tokens as-is → emits ``WsStreamText``
(text may contain inline ``<type>[id,...]</type>`` entity tags
for the frontend to parse and render as interactive components)
* Attaches mutations → injects into ``WsStreamEnd``
FloatingFormatter:
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
* Streams text tokens → emits ``WsStreamText``
* Attaches mutations → injects into ``WsStreamEnd``
"""
"""Output formatter for deep-agent stream events."""
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import (
WsFloatingDomain,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
logger = logging.getLogger(__name__)
# Map sub-agent tool name → floating domain / entity type
_AGENT_DOMAIN: dict[str, str] = {
"task_agent": "tasks",
"timeline_agent": "timelines",
"note_agent": "notes",
"project_agent": "projects",
}
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class HomeFormatter:
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
is responsible for parsing those and rendering interactive components.
Mutations are attached to ``WsStreamEnd``.
"""
class StreamFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._mutations: list[dict] = []
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
yield WsStreamStart(request_id=self.request_id)
started = False
async for event_type, data in event_stream:
if event_type == "token":
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
elif event_type == "mutations":
self._mutations = data or []
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)
class FloatingFormatter:
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
``task_agent`` → ``"tasks"``), then streams text tokens as plain
``WsStreamText``. No block parsing for floating context.
"""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._mutations: list[dict] = []
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
domain_sent = False
async for event_type, data in event_stream:
if event_type == "tool_end" and not domain_sent:
# Sniff domain from the first sub-agent that completes
name = data.get("name", "")
domain = _AGENT_DOMAIN.get(name, "tasks")
yield WsFloatingDomain(
request_id=self.request_id,
domain=domain, # type: ignore[arg-type]
)
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
elif event_type == "token":
if not domain_sent:
# First token arrived before any tool_end — default domain
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
domain=data,
)
yield WsStreamStart(request_id=self.request_id)
domain_sent = True
if data:
yield WsStreamText(request_id=self.request_id, chunk=data)
continue
elif event_type == "mutations":
self._mutations = data or []
if event_type != "token":
continue
# If no events triggered domain_sent (edge case), still emit structure
if not domain_sent:
yield WsFloatingDomain(
request_id=self.request_id,
domain="tasks", # type: ignore[arg-type]
)
if not started:
yield WsStreamStart(request_id=self.request_id)
started = True
text = str(data or "")
if text:
yield WsStreamText(request_id=self.request_id, chunk=text)
if not started:
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(
request_id=self.request_id,
mutations=[
{"action": m["action"], "table": m["table"], "data": m["data"]}
for m in self._mutations
],
)
yield WsStreamEnd(request_id=self.request_id)

View File

@@ -7,21 +7,18 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
from __future__ import annotations
import logging
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
logger = logging.getLogger(__name__)
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the deep agent runs; cleared after.
# Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor"
)
# Optional collector that captures raw execute_on_client results.
# Set by the deep agent tool loop to capture CRUD mutations.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
@@ -84,17 +81,12 @@ async def execute_on_client(
if limit is not None:
payload["limit"] = limit
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
result = await callback(payload)
if result is None:
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
else:
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
collector = _tool_result_collector.get(None)
if collector is not None and action in ("insert", "update", "delete"):
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": data or {},
"data": result,
})
return result

View File

@@ -18,7 +18,9 @@ from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: initialise DB connection pool
# Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401
yield
# Shutdown: dispose SQLAlchemy connection pool
@@ -48,7 +50,7 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
@@ -58,7 +60,6 @@ def create_app() -> FastAPI:
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(agent_setup.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])

View File

@@ -142,9 +142,6 @@ class WsFrameType(str, Enum):
tool_result = "tool_result"
final = "final"
ping = "ping"
agent_run = "agent_run"
agent_data = "agent_data"
agent_complete = "agent_complete"
device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
@@ -156,6 +153,10 @@ class WsFrameType(str, Enum):
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
# ── v4 journey frame types ────────────────────────────────────────
journey_start = "journey_start"
journey_message = "journey_message"
journey_reply = "journey_reply"
class WsToolCall(BaseModel):
@@ -208,31 +209,6 @@ class WsDeviceHello(BaseModel):
agent_ids: list[str] = Field(default_factory=list)
class WsAgentRun(BaseModel):
"""Server → Client: trigger an agent run on the connected device."""
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
run_id: str
agent_id: str
config: dict[str, Any]
class WsAgentData(BaseModel):
"""Client → Server: files read by the local agent."""
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
run_id: str
files: list[dict[str, Any]]
class WsAgentComplete(BaseModel):
"""Client → Server: Electron signals it has finished reading files."""
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
run_id: str
files_read: int
errors: list[str] = Field(default_factory=list)
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
@@ -279,7 +255,14 @@ class WsStreamEnd(BaseModel):
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
mutations: list[dict[str, Any]] = Field(default_factory=list)
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel):
@@ -287,7 +270,7 @@ class WsFloatingDomain(BaseModel):
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: Literal["tasks", "timelines", "notes", "projects"]
domain: WsDomain
# ── Agent Catalog ─────────────────────────────────────────────────────
@@ -296,84 +279,28 @@ class AgentCatalogItem(BaseModel):
type: str
name: str
description: str
config_schema: dict[str, Any] = Field(default_factory=dict)
# ── Local Agent Config ────────────────────────────────────────────────
class LocalAgentConfigCreate(BaseModel):
name: str
device_id: str
directory_paths: list[str]
data_types: list[str]
prompt_template: str
file_extensions: list[str]
schedule_cron: str
class AgentCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0)
class LocalAgentConfigUpdate(BaseModel):
name: str | None = None
device_id: str | None = None
directory_paths: list[str] | None = None
data_types: list[str] | None = None
prompt_template: str | None = None
file_extensions: list[str] | None = None
schedule_cron: str | None = None
enabled: bool | None = None
class AgentCreationCheckResponse(BaseModel):
allowed: bool
tier: BillingTier
active_agents: int
limit: int
class LocalAgentConfigResponse(BaseModel):
id: str
name: str
device_id: str
directory_paths: list[str]
data_types: list[str]
prompt_template: str
file_extensions: list[str]
schedule_cron: str
enabled: bool
last_run_at: int | None
created_at: int
updated_at: int
# ── Cloud Agent Config ────────────────────────────────────────────────
class CloudAgentConfigCreate(BaseModel):
provider: Literal["gmail", "teams", "outlook"]
name: str
data_types: list[str]
prompt_template: str
oauth_token_encrypted: str
schedule_cron: str
filter_config: dict[str, Any] | None = None
class CloudAgentConfigUpdate(BaseModel):
provider: Literal["gmail", "teams", "outlook"] | None = None
name: str | None = None
data_types: list[str] | None = None
prompt_template: str | None = None
oauth_token_encrypted: str | None = None
schedule_cron: str | None = None
filter_config: dict[str, Any] | None = None
enabled: bool | None = None
class CloudAgentConfigResponse(BaseModel):
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
id: str
provider: Literal["gmail", "teams", "outlook"]
name: str
data_types: list[str]
prompt_template: str
schedule_cron: str
filter_config: dict[str, Any] | None
enabled: bool
last_run_at: int | None
created_at: int
updated_at: int
class AgentTriggerRequest(BaseModel):
directory: str = Field(min_length=1)
device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
what_to_extract: list[str] = Field(min_length=1)
actions_by_type: dict[str, list[str]] | None = None
batch_interval: str = Field(min_length=1)
custom_agent_prompt: str = Field(min_length=1)
active_agents: int = Field(ge=0, default=0)
# ── Agent Run Log ─────────────────────────────────────────────────────
@@ -392,18 +319,3 @@ class AgentRunLogResponse(BaseModel):
# ── Chatbot Journey ───────────────────────────────────────────────────
class JourneyStartRequest(BaseModel):
agent_type: Literal["local", "cloud"]
agent_id: str | None = None
class JourneyMessageRequest(BaseModel):
session_id: str
message: str
class JourneyResponse(BaseModel):
session_id: str
message: str
done: bool
prompt_template: str | None = None

View File

@@ -4,8 +4,6 @@ gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
langgraph>=0.3.0
deepagents>=0.4.10
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0

View File

@@ -10,13 +10,13 @@ Coverage:
- run_local_agent — file-read timeout path
- run_local_agent — LLM extraction error path
- run_cloud_agent — stub returns error immediately
- trigger_pending_runs — overdue local + cloud dispatched
- trigger_pending_runs — skipped when config is client-owned
- trigger_pending_runs — non-overdue skipped
- trigger_pending_runs — device_id filter for local agents
Integration:
- POST /agents/{id}/run — 404 on unknown agent
- POST /agents/{id}/run — creates run log + dispatches background task
Integration:
- POST /agents/can-create — billing eligibility check
- POST /agents/trigger — creates run log + dispatches background task
"""
from __future__ import annotations
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1
assert kwargs["errors"] == []
assert kwargs["update_config_last_run"] is True
assert kwargs["update_config_last_run"] is False
# Verify agent_run frame was sent.
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
@@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
@pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue():
"""If no agents are overdue trigger_pending_runs does nothing."""
from datetime import timedelta
config = _make_local_config()
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [config]
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
"""Pending-run scan is skipped because agent config is client-owned."""
mgr = _make_manager()
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
mock_session_factory.return_value = mock_ctx
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
@@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue():
@pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter():
"""Local agents are only triggered for the matching device_id."""
# The DB query already filters by device_id, so we verify the SELECT
# includes the device_id filter by checking that a config bound to a
# different device is never dispatched.
#
# Since trigger_pending_runs queries with device_id == "dev-001",
# simulate the DB returning an empty list (as it would for a mismatch).
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
"""Device filtering is no longer backend-managed in pending runs."""
mgr = _make_manager(device_id="dev-001")
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
mock_session_factory.return_value = mock_ctx
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
@@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter():
@pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue():
"""Overdue local agent triggers run_local_agent sequentially."""
config = _make_local_config() # last_run_at=None → always overdue
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [config]
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
"""No pending runs are dispatched by backend after config deprecation."""
mgr = _make_manager()
call_order: list[str] = []
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
call_order.append("run_local")
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
# First call: query configs. Subsequent calls: create run_log.
mock_query_ctx = AsyncMock()
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
mock_query_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
run_log_obj = AgentRunLog(
id=str(uuid.uuid4()),
agent_id=config.id,
agent_type="local",
user_id=_FREE_UID,
status="running",
started_at=datetime.now(timezone.utc),
)
mock_insert_ctx = AsyncMock()
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
mock_insert_ctx.add = MagicMock()
mock_insert_ctx.commit = AsyncMock()
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
assert call_order == ["run_local"]
mock_run.assert_not_called()
# ---------------------------------------------------------------------------
# Integration: POST /agents/{id}/run
# Integration: POST /agents/can-create and /agents/trigger
# ---------------------------------------------------------------------------
@@ -820,50 +742,67 @@ def _override_db(db_session):
@pytest.mark.asyncio
async def test_trigger_run_unknown_agent(client):
"""POST /agents/{id}/run returns 404 for unknown agent id."""
async def test_can_create_agent_allows_when_under_limit(client):
"""POST /agents/can-create returns allowed=True when under tier limit."""
resp = client.post(
f"/api/v1/agents/{uuid.uuid4()}/run",
headers=auth_header("power"),
"/api/v1/agents/can-create",
json={"active_agents": 0},
headers=auth_header("free"),
)
assert resp.status_code == 404
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is True
assert body["tier"] == "free"
assert body["active_agents"] == 0
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
"""POST /agents/can-create returns allowed=False at free-tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 2},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is False
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
# Create the local agent config in the DB.
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=TEST_USER_IDS["power"],
device_id="dev-001",
name="My Agent",
directory_paths=["/home/user/docs"],
data_types=["tasks"],
prompt_template="Extract tasks.",
file_extensions=[".txt"],
schedule_cron="0 */6 * * *",
enabled=True,
)
db_session.add(config)
await db_session.commit()
dispatched: list = []
"""POST /agents/trigger creates a local run log and dispatches background task."""
dispatched: list[tuple[str, str]] = []
async def _fake_run(user_id, cfg, run_log, device_mgr):
dispatched.append((user_id, cfg.id))
def _fake_create_task(coro):
coro.close()
return MagicMock()
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = _fake_create_task
resp = client.post(
f"/api/v1/agents/{config.id}/run",
"/api/v1/agents/trigger",
json={
"directory": "/home/user/docs",
"what_to_extract": ["task", "note"],
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
"batch_interval": "0 */6 * * *",
"custom_agent_prompt": "Extract tasks and notes.",
"active_agents": 0,
},
headers=auth_header("power"),
)
assert resp.status_code == 202
data = resp.json()
assert data["agent_id"] == config.id
assert isinstance(data["agent_id"], str)
assert data["agent_id"]
assert data["status"] == "running"
assert data["agent_type"] == "local"

288
tests/test_deep_agent.py Normal file
View File

@@ -0,0 +1,288 @@
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
from __future__ import annotations
from datetime import date, timedelta
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import (
_infer_floating_domain,
_normalize_tagged_list_lines,
run_floating,
run_floating_stream,
run_home,
)
class _FakeTool:
name = "list_tasks"
async def ainvoke(self, args):
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
class _FakeLLM:
def __init__(self) -> None:
self.agent_calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, messages):
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
if "strict domain classifier" in system_prompt:
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
self.agent_calls += 1
if self.agent_calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {"project_id": "proj-1"},
}
],
)
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
assert tool_messages, "Expected at least one tool message"
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
async def astream(self, _messages):
yield SimpleNamespace(content="stream-")
yield SimpleNamespace(content="ok")
@pytest.mark.asyncio
async def test_run_home_uses_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
out = await run_home("user-1", "list my tasks", {})
assert "Final answer from mocked tool" in out
assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
assert ("token", "stream-") in events
assert ("token", "ok") in events
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = (
"Certo!\n\n"
"1. **Task A** — priorita high <task>[task-1]</task>\n"
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
)
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
assert "<task>[task-1]</task>" in out
assert "<task>[task-2]</task>" in out
assert "Task A" not in out
assert "Task B" not in out
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
today = date.today()
tomorrow = today + timedelta(days=1)
yesterday = today - timedelta(days=1)
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
raw = "\n".join(
[
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
]
)
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
assert "<timeline>[tl-next]</timeline>" in out
assert "<timeline>[tl-old]</timeline>" not in out
assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events

View File

@@ -110,6 +110,32 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
@pytest.mark.asyncio
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
target_session = str(uuid.uuid4())
other_session = str(uuid.uuid4())
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("Target session memory"),
session_id=target_session,
))
db_session.add(MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=USER_ID,
summary_encrypted=_enc("Other session memory"),
session_id=other_session,
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
episodic = ctx.get("episodic_memory", [])
assert any("Target session" in s for s in episodic)
assert not any("Other session" in s for s in episodic)
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
# Add one pattern above threshold and one below
@@ -229,6 +255,40 @@ async def test_update_core_upsert(db_session, user_with_key):
assert _dec(rows[0].value_encrypted) == "fr"
@pytest.mark.asyncio
async def test_core_block_edit_ops(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.update_core(USER_ID, "human", "Name: Roberto")
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
blocks = await middleware.list_core_blocks(USER_ID)
human = next(b for b in blocks if b["label"] == "human")
assert replaced is True
assert "Name: Robert" in human["value"]
assert "Timezone: Europe/Rome" in human["value"]
deleted = await middleware.delete_core(USER_ID, "human")
assert deleted is True
assert await middleware.get_core_block(USER_ID, "human") is None
@pytest.mark.asyncio
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
middleware = MemoryMiddleware(db_session)
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
assert any("whitelist" in item.lower() for item in arch)
assert any("delayed" in item.lower() for item in rec)
# ── End-to-end WS: memory middleware is called during home_request ────────────
def test_home_request_calls_memory_middleware(client):
@@ -240,21 +300,20 @@ def test_home_request_calls_memory_middleware(client):
def __init__(self, db):
pass
async def enrich_context(self, user_id, message):
async def enrich_context(self, user_id, message, **kwargs):
enrich_calls.append((user_id, message))
return {"core_memory": {"tz": "UTC"}}
async def store_episode(self, user_id, session_id, message, response):
async def store_episode(self, user_id, session_id, message, response, **kwargs):
store_calls.append((user_id, session_id, message, response))
token = make_jwt("power", user_id=USER_ID)
session_id = str(uuid.uuid4())
async def _mock_stream(user_id, message, context, db_session_factory=None):
async def _mock_stream(user_id, message, context):
# Verify memory context was injected
assert context.get("core_memory") == {"tz": "UTC"}
yield ("token", "Done")
yield ("mutations", [])
yield "token", "Done"
with (
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),

View File

@@ -1,214 +1,82 @@
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
"""Tests for app.core.output_formatter.StreamFormatter."""
from __future__ import annotations
import pytest
from app.core.output_formatter import HomeFormatter, FloatingFormatter
from app.schemas import (
WsFloatingDomain,
WsStreamEnd,
WsStreamStart,
WsStreamText,
)
from app.core.output_formatter import StreamFormatter
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
# ── helpers ───────────────────────────────────────────────────────────────────
async def _stream(*events: tuple[str, object]):
"""Async generator that yields (event_type, data) tuples."""
for event in events:
yield event
async def collect(formatter, event_stream):
async def _collect(formatter: StreamFormatter, event_stream):
frames = []
async for frame in formatter.format(event_stream):
frames.append(frame)
return frames
# ── HomeFormatter ─────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_home_formatter_plain_text():
req_id = "req-1"
events = [
("token", "Hello world"),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
async def test_stream_formatter_text_stream() -> None:
formatter = StreamFormatter(request_id="req-1")
frames = await _collect(
formatter,
_stream(("token", "Hello"), ("token", " world")),
)
assert isinstance(frames[0], WsStreamStart)
assert frames[0].request_id == req_id
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert any("Hello world" in f.chunk for f in text_frames)
assert isinstance(frames[1], WsStreamText)
assert frames[1].chunk == "Hello"
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == " world"
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_home_formatter_entity_tags_passed_through():
"""Entity tags are streamed as-is — the frontend parses them."""
req_id = "req-2"
events = [
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2")
frames = await _collect(
formatter,
_stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
)
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert "<project>[abc-123]</project>" in text
assert "Here is your project:" in text
assert "All good." in text
@pytest.mark.asyncio
async def test_home_formatter_multiple_tags_passed_through():
req_id = "req-3"
events = [
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert "<project>[p1]</project>" in text
assert "<task>[t1,t2]</task>" in text
@pytest.mark.asyncio
async def test_home_formatter_tool_end_ignored():
"""tool_end events are silently ignored by HomeFormatter."""
req_id = "req-4"
events = [
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
("token", "No tags here."),
("mutations", []),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
assert text == "No tags here."
@pytest.mark.asyncio
async def test_home_formatter_mutations_in_stream_end():
req_id = "req-5"
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
events = [
("token", "Done"),
("mutations", muts),
]
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
end_frame = frames[-1]
assert isinstance(end_frame, WsStreamEnd)
assert len(end_frame.mutations) == 1
assert end_frame.mutations[0]["action"] == "insert"
@pytest.mark.asyncio
async def test_home_formatter_frame_order():
"""stream_start is first, stream_end is last."""
req_id = "req-6"
formatter = HomeFormatter(request_id=req_id)
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary"
assert isinstance(frames[-1], WsStreamEnd)
# ── FloatingFormatter ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_floating_formatter_domain_from_tool_end():
req_id = "pop-1"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "task_agent", "result": "ok"}),
("token", "Hello"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
async def test_stream_formatter_ignores_unknown_events() -> None:
formatter = StreamFormatter(request_id="req-3")
frames = await _collect(
formatter,
_stream(("tool_end", {"name": "x"}), ("token", "ok")),
)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks"
assert frames[0].request_id == req_id
@pytest.mark.asyncio
async def test_floating_formatter_text_only():
req_id = "pop-2"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "timeline_agent", "result": "done"}),
("token", "Summary"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "timelines"
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
assert len(text_frames) == 1
assert text_frames[0].chunk == "Summary"
assert text_frames[0].chunk == "ok"
@pytest.mark.asyncio
async def test_floating_formatter_no_entity_tags():
"""FloatingFormatter never emits entity tag blocks."""
req_id = "pop-3"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "note_agent", "result": "data"}),
("token", "some text"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
# Only expected frame types
for f in frames:
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
async def test_stream_formatter_empty_stream_still_brackets() -> None:
formatter = StreamFormatter(request_id="req-4")
frames = await _collect(formatter, _stream())
@pytest.mark.asyncio
async def test_floating_formatter_end_frame():
req_id = "pop-4"
formatter = FloatingFormatter(request_id=req_id)
events = [
("tool_end", {"name": "project_agent", "result": "ok"}),
("token", "Done"),
("mutations", []),
]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_floating_formatter_default_domain_on_early_token():
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
req_id = "pop-5"
formatter = FloatingFormatter(request_id=req_id)
events = [("token", "hi"), ("mutations", [])]
frames = await collect(formatter, _stream(*events))
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain == "tasks"
@pytest.mark.asyncio
async def test_floating_formatter_mutations_in_stream_end():
req_id = "pop-6"
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
events = [
("token", "Updated"),
("mutations", muts),
]
formatter = FloatingFormatter(request_id=req_id)
frames = await collect(formatter, _stream(*events))
end_frame = frames[-1]
assert isinstance(end_frame, WsStreamEnd)
assert len(end_frame.mutations) == 1
assert len(frames) == 2
assert isinstance(frames[0], WsStreamStart)
assert isinstance(frames[1], WsStreamEnd)

View File

@@ -88,7 +88,7 @@ class TestPluginRegistry:
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time tracker")
result = await reg.list_plugins(db_session, query="time")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"

View File

@@ -4,6 +4,7 @@ import pytest
from pydantic import ValidationError
from app.schemas import (
WsDomain,
WsFrameType,
WsHomeRequest,
WsFloatingDomain,
@@ -178,23 +179,15 @@ def test_stream_text_deserializes():
def test_stream_end_defaults():
frame = WsStreamEnd(request_id="r1")
assert frame.type == WsFrameType.stream_end
assert frame.mutations == []
def test_stream_end_with_mutations():
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
frame = WsStreamEnd(request_id="r1", mutations=mutations)
assert len(frame.mutations) == 1
assert frame.mutations[0]["action"] == "create"
def test_stream_end_serializes():
data = WsStreamEnd(request_id="r2").model_dump()
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
assert data == {"type": "stream_end", "request_id": "r2"}
def test_stream_end_deserializes():
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
raw = {"type": "stream_end", "request_id": "r3"}
frame = WsStreamEnd.model_validate(raw)
assert frame.request_id == "r3"
@@ -203,28 +196,47 @@ def test_stream_end_deserializes():
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain="tasks")
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain
assert frame.domain == "tasks"
assert frame.domain.type == "task"
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
def test_floating_domain_valid_domains(domain: str):
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
assert frame.domain == domain
def test_floating_domain_valid_domains():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_invalid():
with pytest.raises(ValidationError):
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
def test_floating_domain_object_valid():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes():
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
d = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes():
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain == "projects"
assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -45,15 +45,13 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
return frames
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
yield "mutations", []
async def _mock_home_stream(user_id, message, context):
yield "token", "Hello"
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
yield "tool_end", {"name": "task_agent", "result": "ok"}
async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary"
yield "mutations", []
# ── tests ─────────────────────────────────────────────────────────────────────
@@ -104,7 +102,7 @@ def test_floating_request_produces_domain_frame(client):
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"] == "tasks"
assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1"
@@ -113,9 +111,8 @@ def test_home_request_request_id_propagated(client):
token = make_jwt("power", user_id=USER_ID)
req_id = "my-unique-req-id"
async def _stream(user_id, message, context, db_session_factory=None):
async def _stream(user_id, message, context):
yield "token", "ok"
yield "mutations", []
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws: