From 2b7d302ef28397e797d23ef2fd525d63fa29c2b1 Mon Sep 17 00:00:00 2001 From: Roberto Musso Date: Mon, 6 Apr 2026 23:44:12 +0200 Subject: [PATCH] refactor: remove monolith app/, Dockerfile, requirements.txt All business logic has been extracted into microservices: - services/auth/ (Step 1) - services/ws-gateway/ (Step 2) - services/chat/ (Step 2) - services/batch-agent/ (Step 3) - services/billing/ (Step 4) Shared code lives in shared/. Migrations remain in alembic/. Tests in tests/ will need updating to target individual services. --- Dockerfile | 39 - app/__init__.py | 0 app/agents/__init__.py | 5 - app/agents/filesystem_agent.py | 85 --- app/agents/note_agent.py | 139 ---- app/agents/project_agent.py | 143 ---- app/agents/task_agent.py | 238 ------- app/agents/timeline_agent.py | 114 --- app/api/__init__.py | 0 app/api/deps.py | 14 - app/api/middleware/__init__.py | 19 - app/api/middleware/auth.py | 80 --- app/api/middleware/rate_limit.py | 129 ---- app/api/middleware/sanitizer.py | 139 ---- app/api/routes/__init__.py | 0 app/api/routes/agent_setup.py | 406 ----------- app/api/routes/agents.py | 222 ------ app/api/routes/auth.py | 235 ------ app/api/routes/backup.py | 171 ----- app/api/routes/billing.py | 85 --- app/api/routes/chat.py | 29 - app/api/routes/device_ws.py | 417 ----------- app/api/routes/plugins.py | 148 ---- app/api/routes/storage.py | 195 ----- app/api/routes/vectors.py | 79 --- app/billing/__init__.py | 4 - app/billing/stripe_service.py | 256 ------- app/billing/tier_manager.py | 195 ----- app/config/__init__.py | 0 app/config/settings.py | 62 -- app/core/__init__.py | 0 app/core/agent_registry.py | 30 - app/core/agent_runner.py | 1064 ---------------------------- app/core/deep_agent.py | 846 ---------------------- app/core/device_manager.py | 151 ---- app/core/llm.py | 122 ---- app/core/memory_middleware.py | 441 ------------ app/core/output_formatter.py | 47 -- app/core/ws_context.py | 92 --- app/db.py | 40 -- app/integrations/__init__.py | 164 ----- app/integrations/gmail.py | 335 --------- app/integrations/ms_graph.py | 352 --------- app/main.py | 72 -- app/marketplace/__init__.py | 7 - app/marketplace/plugin_registry.py | 212 ------ app/marketplace/plugin_review.py | 125 ---- app/marketplace/revenue_share.py | 233 ------ app/models.py | 476 ------------- app/schemas.py | 321 --------- app/storage/__init__.py | 1 - app/storage/blob_store.py | 106 --- app/storage/encryption.py | 32 - app/storage/vector_store.py | 205 ------ requirements.txt | 37 - 55 files changed, 9159 deletions(-) delete mode 100644 Dockerfile delete mode 100644 app/__init__.py delete mode 100644 app/agents/__init__.py delete mode 100644 app/agents/filesystem_agent.py delete mode 100644 app/agents/note_agent.py delete mode 100644 app/agents/project_agent.py delete mode 100644 app/agents/task_agent.py delete mode 100644 app/agents/timeline_agent.py delete mode 100644 app/api/__init__.py delete mode 100644 app/api/deps.py delete mode 100644 app/api/middleware/__init__.py delete mode 100644 app/api/middleware/auth.py delete mode 100644 app/api/middleware/rate_limit.py delete mode 100644 app/api/middleware/sanitizer.py delete mode 100644 app/api/routes/__init__.py delete mode 100644 app/api/routes/agent_setup.py delete mode 100644 app/api/routes/agents.py delete mode 100644 app/api/routes/auth.py delete mode 100644 app/api/routes/backup.py delete mode 100644 app/api/routes/billing.py delete mode 100644 app/api/routes/chat.py delete mode 100644 app/api/routes/device_ws.py delete mode 100644 app/api/routes/plugins.py delete mode 100644 app/api/routes/storage.py delete mode 100644 app/api/routes/vectors.py delete mode 100644 app/billing/__init__.py delete mode 100644 app/billing/stripe_service.py delete mode 100644 app/billing/tier_manager.py delete mode 100644 app/config/__init__.py delete mode 100644 app/config/settings.py delete mode 100644 app/core/__init__.py delete mode 100644 app/core/agent_registry.py delete mode 100644 app/core/agent_runner.py delete mode 100644 app/core/deep_agent.py delete mode 100644 app/core/device_manager.py delete mode 100644 app/core/llm.py delete mode 100644 app/core/memory_middleware.py delete mode 100644 app/core/output_formatter.py delete mode 100644 app/core/ws_context.py delete mode 100644 app/db.py delete mode 100644 app/integrations/__init__.py delete mode 100644 app/integrations/gmail.py delete mode 100644 app/integrations/ms_graph.py delete mode 100644 app/main.py delete mode 100644 app/marketplace/__init__.py delete mode 100644 app/marketplace/plugin_registry.py delete mode 100644 app/marketplace/plugin_review.py delete mode 100644 app/marketplace/revenue_share.py delete mode 100644 app/models.py delete mode 100644 app/schemas.py delete mode 100644 app/storage/__init__.py delete mode 100644 app/storage/blob_store.py delete mode 100644 app/storage/encryption.py delete mode 100644 app/storage/vector_store.py delete mode 100644 requirements.txt diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 32496db..0000000 --- a/Dockerfile +++ /dev/null @@ -1,39 +0,0 @@ -# ── builder ────────────────────────────────────────────────────────────────── -FROM python:3.12-slim AS builder - -WORKDIR /build - -COPY requirements.txt . -RUN pip install --upgrade pip && \ - pip install --no-cache-dir --prefix=/install -r requirements.txt - -# ── runtime ────────────────────────────────────────────────────────────────── -FROM python:3.12-slim AS runtime - -# Non-root user -RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser - -WORKDIR /app - -# Copy installed packages from builder -COPY --from=builder /install /usr/local - -# Copy application source -COPY app/ app/ - -# Copy Alembic migration files -COPY alembic/ alembic/ -COPY alembic.ini . - -# Ensure appuser owns the working directory -RUN chown -R appuser:appgroup /app - -USER appuser - -EXPOSE 8000 - -CMD ["gunicorn", "app.main:app", \ - "-k", "uvicorn.workers.UvicornWorker", \ - "--bind", "0.0.0.0:8000", \ - "--workers", "4", \ - "--timeout", "120"] diff --git a/app/__init__.py b/app/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/agents/__init__.py b/app/agents/__init__.py deleted file mode 100644 index a2dc4c6..0000000 --- a/app/agents/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Expose tool modules used by deep orchestrator-worker graphs.""" - -from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent - -__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"] diff --git a/app/agents/filesystem_agent.py b/app/agents/filesystem_agent.py deleted file mode 100644 index 8e6018c..0000000 --- a/app/agents/filesystem_agent.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Filesystem agent — tools for reading local directories and files on Electron. - -These tools delegate to the Electron client via ``execute_on_client()`` using -the same WS tool-call round-trip pattern as CRUD tools. The Electron app -handles actual disk I/O and responds with ``tool_result`` frames. -""" - -from __future__ import annotations - -from typing import Any - -from langchain_core.tools import tool - -from app.core.ws_context import execute_on_client - - -@tool -async def list_directory(path: str) -> str: - """List files and folders in a local directory on the user's device. - - Returns a formatted listing of entries with name, type (file/directory), - and full path. - """ - result = await execute_on_client( - action="list_directory", - data={"path": path}, - ) - entries: list[dict[str, Any]] = result.get("entries", []) - if not entries: - return f"Directory '{path}' is empty or does not exist." - lines: list[str] = [] - for entry in entries: - entry_type = entry.get("type", "unknown") - entry_name = entry.get("name", "") - entry_path = entry.get("path", "") - lines.append(f"- [{entry_type}] {entry_name} ({entry_path})") - return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines) - - -@tool -async def read_file_content(path: str) -> str: - """Read the text content of a local file on the user's device. - - Returns the file content as a string. Large files may be truncated - by the Electron client. - """ - result = await execute_on_client( - action="read_file_content", - data={"path": path}, - ) - content: str = result.get("content", "") - if not content: - return f"File '{path}' is empty or could not be read." - return content - - -@tool -async def get_file_metadata(path: str) -> str: - """Get metadata for a local file: size, creation date, modification date, extension. - - Returns a formatted summary of the file's metadata. - """ - result = await execute_on_client( - action="get_file_metadata", - data={"path": path}, - ) - size = result.get("size", "unknown") - created = result.get("createdAt", "unknown") - modified = result.get("modifiedAt", "unknown") - extension = result.get("extension", "unknown") - name = result.get("name", path) - return ( - f"File: {name}\n" - f" Extension: {extension}\n" - f" Size: {size} bytes\n" - f" Created: {created}\n" - f" Modified: {modified}" - ) - - -FILESYSTEM_TOOLS: list[Any] = [ - list_directory, - read_file_content, - get_file_metadata, -] diff --git a/app/agents/note_agent.py b/app/agents/note_agent.py deleted file mode 100644 index cae644b..0000000 --- a/app/agents/note_agent.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Note agent — Markdown note management (list, get, create, update, delete).""" - -from __future__ import annotations - -import re -from typing import Any - -from langchain_core.tools import tool - -from app.core.llm import embed -from app.core.ws_context import execute_on_client - -_UUID_RE = re.compile( - r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" -) - - -def _is_uuid(value: str) -> bool: - return bool(_UUID_RE.match(value)) - -NOTE_SYSTEM_PROMPT = ( - "You are a note-taking assistant. You help users create, retrieve, update,\n" - "and delete Markdown notes in their workspace.\n\n" - "Rules:\n" - " - content is always Markdown; preserve formatting when updating\n" - " - project_id is optional; link a note to a project when mentioned\n" - " - When updating, call get_note first if you need to read existing content\n" - " before appending or replacing sections\n" - " - list_notes without project_id returns all notes; scope with project_id\n" - " when the user is working within a specific project\n" - " - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n" - " - Do not fabricate note content — reflect what the user provides or what\n" - " is already in the note (retrieved via get_note)." -) - - -@tool -async def list_notes(project_id: str = "") -> str: - """List notes, optionally scoped to a project by project_id.""" - normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" - result = await execute_on_client( - action="select", - table="notes", - filters={"projectId": normalized_project_id or None}, - ) - rows = result.get("rows", []) - if not rows: - return "No notes found." - lines = [f"- {r['title']} (id: {r['id']})" for r in rows] - return f"Found {len(rows)} note(s):\n" + "\n".join(lines) - - -@tool -async def get_note(note_id: str) -> str: - """Fetch a single note by its UUID to read its full Markdown content.""" - result = await execute_on_client(action="get", table="notes", data={"id": note_id}) - row = result.get("row") - if not row: - return f"Note {note_id} not found." - return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}" - - -@tool -async def create_note( - title: str, - content: str, - project_id: str = "", -) -> str: - """Create a new note. - title: note heading (required) - content: Markdown body text (required) - project_id: optional UUID linking this note to a project - """ - result = await execute_on_client( - action="insert", - table="notes", - data={ - "title": title, - "content": content, - "projectId": project_id or None, - }, - ) - row = result["row"] - # Index the note content in the vector store. - vector = await embed(content) - await execute_on_client( - action="vector_upsert", - data={"id": row["id"], "projectId": row.get("projectId"), "content": content}, - vector=vector, - ) - return f"Note created: '{row['title']}' (id: {row['id']})." - - -@tool -async def update_note( - note_id: str, - title: str = "", - content: str = "", -) -> str: - """Update an existing note. Only pass fields that should change. - note_id: UUID of the note (required) - If you need to preserve existing content, call get_note first. - """ - updates: dict[str, Any] = {} - if title: - updates["title"] = title - if content: - updates["content"] = content - result = await execute_on_client( - action="update", - table="notes", - data={"id": note_id, "updates": updates}, - ) - row = result["row"] - # Re-index if content changed. - if content: - vector = await embed(content) - await execute_on_client( - action="vector_upsert", - data={"id": note_id, "projectId": row.get("projectId"), "content": content}, - vector=vector, - ) - return f"Note updated: '{row['title']}' (id: {row['id']})." - - -@tool -async def delete_note(note_id: str) -> str: - """Delete a note permanently by its UUID.""" - await execute_on_client(action="delete", table="notes", data={"id": note_id}) - return f"Note {note_id} deleted." - - -NOTE_TOOLS: list[Any] = [ - list_notes, - get_note, - create_note, - update_note, - delete_note, -] diff --git a/app/agents/project_agent.py b/app/agents/project_agent.py deleted file mode 100644 index a07da0e..0000000 --- a/app/agents/project_agent.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Project agent — full lifecycle management (list, get, create, update, archive, delete).""" - -from __future__ import annotations - -from typing import Any - -from langchain_core.tools import tool - -from app.core.ws_context import execute_on_client - -PROJECT_SYSTEM_PROMPT = ( - "You are a project management assistant. You help users create, find,\n" - "update, and archive projects in their workspace.\n\n" - "Rules:\n" - " - status must be one of: active, archived\n" - " - client_id is optional; link to a client only when explicitly mentioned\n" - " - ai_summary is populated only when the user asks for a project summary;\n" - " derive it from context data — do not fabricate content\n" - " - Use list_projects for scoped queries; list_all_projects only when the\n" - " user wants a complete cross-client view including archived projects\n" - " - get_project requires a project UUID; resolve the ID first by calling\n" - " list_projects if you only have a project name\n" - " - Prefer archiving (update_project status=archived) over deletion;\n" - " only call delete_project when the user explicitly confirms deletion." -) - - -@tool -async def list_projects( - client_id: str = "", - include_archived: int = 0, -) -> str: - """List projects, optionally filtered by client_id. - include_archived: 1 to include archived projects, 0 for active only (default). - """ - result = await execute_on_client( - action="select", - table="projects", - filters={ - "clientId": client_id or None, - "includeArchived": bool(include_archived), - }, - ) - rows = result.get("rows", []) - if not rows: - return "No projects found." - lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows] - return f"Found {len(rows)} project(s):\n" + "\n".join(lines) - - -@tool -async def list_all_projects() -> str: - """List every project regardless of client or status. - Use only when the user wants a complete cross-client overview. - """ - result = await execute_on_client(action="select", table="projects") - rows = result.get("rows", []) - if not rows: - return "No projects found." - lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows] - return f"All projects ({len(rows)}):\n" + "\n".join(lines) - - -@tool -async def get_project(project_id: str) -> str: - """Fetch a single project by its UUID.""" - result = await execute_on_client(action="get", table="projects", data={"id": project_id}) - row = result.get("row") - if not row: - return f"Project {project_id} not found." - return ( - f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, " - f"clientId: {row.get('clientId', 'none')})" - ) - - -@tool -async def create_project( - name: str, - client_id: str = "", -) -> str: - """Create a new project. - name: human-readable project name (required) - client_id: optional UUID of the owning client - """ - result = await execute_on_client( - action="insert", - table="projects", - data={"name": name, "clientId": client_id or None}, - ) - row = result["row"] - return f"Project created: '{row['name']}' (id: {row['id']})" - - -@tool -async def update_project( - project_id: str, - name: str = "", - client_id: str = "", - status: str = "", - ai_summary: str = "", -) -> str: - """Update a project. Only pass fields that should change. - project_id: UUID of the project (required) - status: active | archived - ai_summary: AI-generated summary text (populate only when explicitly requested) - """ - updates: dict[str, Any] = {} - if name: - updates["name"] = name - if client_id: - updates["clientId"] = client_id - if status: - updates["status"] = status - if ai_summary: - updates["aiSummary"] = ai_summary - result = await execute_on_client( - action="update", - table="projects", - data={"id": project_id, "updates": updates}, - ) - row = result["row"] - return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})" - - -@tool -async def delete_project(project_id: str) -> str: - """Permanently delete a project and orphan its tasks. - IMPORTANT: prefer update_project(status='archived') unless the user - has explicitly confirmed they want permanent deletion. - """ - await execute_on_client(action="delete", table="projects", data={"id": project_id}) - return f"Project {project_id} permanently deleted." - - -PROJECT_TOOLS: list[Any] = [ - list_projects, - list_all_projects, - get_project, - create_project, - update_project, - delete_project, -] diff --git a/app/agents/task_agent.py b/app/agents/task_agent.py deleted file mode 100644 index 5be4632..0000000 --- a/app/agents/task_agent.py +++ /dev/null @@ -1,238 +0,0 @@ -"""Task agent — full CRUD for tasks and task comments.""" - -from __future__ import annotations - -from datetime import datetime, timezone -import re -from typing import Any - -from langchain_core.tools import tool - -from app.core.ws_context import execute_on_client - -_UUID_RE = re.compile( - r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" -) - - -def _is_uuid(value: str) -> bool: - return bool(_UUID_RE.match(value)) - -TASK_SYSTEM_PROMPT = ( - "You are a task management assistant for a project workspace.\n" - "You create, update, list, and track tasks and their comments.\n\n" - "Rules:\n" - " - status must be one of: todo, in_progress, done\n" - " - priority must be one of: high, medium, low\n" - " - due_date is a Unix timestamp in milliseconds; convert human dates\n" - " - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n" - " - project_id is optional; link to a project when the user mentions one\n" - " - is_ai_suggested: 1 only when proactively proposing a task the user\n" - " did not explicitly request; 0 otherwise\n" - " - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n" - " - Use list_tasks_due_today for 'what's due today' queries\n" - " - For update_task, use -1 for integer fields you do not want to change\n" - " - Always confirm the action in plain, user-friendly language." -) - - -# ── Task tools ──────────────────────────────────────────────────────── - - -@tool -async def list_tasks( - project_id: str = "", - status: str = "", - search: str = "", - order_by: str = "", -) -> str: - """List tasks, optionally filtered by project_id, status (todo|in_progress|done), - a search string, or an order_by field name (dueDate|priority|createdAt).""" - normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" - result = await execute_on_client( - action="select", - table="tasks", - filters={ - "projectId": normalized_project_id or None, - "status": status or None, - "search": search or None, - "orderBy": order_by or None, - }, - ) - rows = result.get("rows", []) - if not rows: - return "No tasks found matching the given filters." - lines = [ - f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" - for r in rows - ] - return f"Found {len(rows)} task(s):\n" + "\n".join(lines) - - -@tool -async def create_task( - title: str, - description: str = "", - status: str = "todo", - priority: str = "medium", - assignees: str = "[]", - due_date: int = 0, - project_id: str = "", - is_ai_suggested: int = 0, -) -> str: - """Create a new task. - title: task title (required) - description: optional details - status: todo | in_progress | done (default: todo) - priority: high | medium | low (default: medium) - assignees: JSON-encoded array of assignee names, e.g. '["Alice"]' - due_date: Unix timestamp in milliseconds; 0 means no due date - project_id: optional UUID of the parent project - is_ai_suggested: 1 if proactively suggested, 0 if user-requested - """ - result = await execute_on_client( - action="insert", - table="tasks", - data={ - "title": title, - "description": description or None, - "status": status, - "priority": priority, - "assignee": assignees, - "dueDate": due_date or None, - "projectId": project_id or None, - "isAiSuggested": is_ai_suggested, - }, - ) - row = result["row"] - return ( - f"Task created: '{row['title']}' " - f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})" - ) - - -@tool -async def update_task( - task_id: str, - title: str = "", - description: str = "", - status: str = "", - priority: str = "", - assignees: str = "", - due_date: int = -1, - project_id: str = "", -) -> str: - """Update fields on an existing task. Only pass fields you want to change. - task_id: the task's UUID (required) - due_date: -1 means unchanged; 0 clears the due date; any positive value sets it - """ - updates: dict[str, Any] = {} - if title: - updates["title"] = title - if description: - updates["description"] = description - if status: - updates["status"] = status - if priority: - updates["priority"] = priority - if assignees: - updates["assignee"] = assignees - if due_date != -1: - updates["dueDate"] = due_date or None - if project_id: - updates["projectId"] = project_id - result = await execute_on_client( - action="update", - table="tasks", - data={"id": task_id, "updates": updates}, - ) - row = result["row"] - return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})" - - -@tool -async def delete_task(task_id: str) -> str: - """Delete a task permanently by its UUID.""" - await execute_on_client(action="delete", table="tasks", data={"id": task_id}) - return f"Task {task_id} deleted." - - -@tool -async def list_tasks_due_today() -> str: - """List all tasks whose due date falls on today's date.""" - now = datetime.now(tz=timezone.utc) - start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000) - end_ms = start_ms + 86_400_000 - 1 # last ms of today - result = await execute_on_client( - action="select", - table="tasks", - filters={"dueDateFrom": start_ms, "dueDateTo": end_ms}, - ) - rows = result.get("rows", []) - if not rows: - return "No tasks are due today." - lines = [ - f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})" - for r in rows - ] - return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines) - - -# ── Task comment tools ──────────────────────────────────────────────── - - -@tool -async def list_task_comments(task_id: str) -> str: - """List all comments on a task by its UUID.""" - result = await execute_on_client( - action="select", - table="taskComments", - filters={"taskId": task_id}, - ) - rows = result.get("rows", []) - if not rows: - return f"No comments found for task {task_id}." - lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows] - return f"Found {len(rows)} comment(s):\n" + "\n".join(lines) - - -@tool -async def add_task_comment(task_id: str, author: str, content: str) -> str: - """Add a comment to a task. - task_id: UUID of the task to comment on - author: name or ID of the comment author - content: comment text - """ - result = await execute_on_client( - action="insert", - table="taskComments", - data={"taskId": task_id, "author": author, "content": content}, - ) - row = result.get("row", {}) - row_author = row.get("author", author) - # Electron payloads can vary (taskId vs task_id). Fall back to input task_id. - row_task_id = row.get("taskId") or row.get("task_id") or task_id - row_comment_id = row.get("id", "unknown") - return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})." - - -@tool -async def delete_task_comment(comment_id: str) -> str: - """Delete a task comment by its UUID.""" - await execute_on_client(action="delete", table="taskComments", data={"id": comment_id}) - return f"Comment {comment_id} deleted." - - -# ── Agent ───────────────────────────────────────────────────────────── - - -TASK_TOOLS: list[Any] = [ - list_tasks, - create_task, - update_task, - delete_task, - list_tasks_due_today, - list_task_comments, - add_task_comment, - delete_task_comment, -] diff --git a/app/agents/timeline_agent.py b/app/agents/timeline_agent.py deleted file mode 100644 index 4c7a217..0000000 --- a/app/agents/timeline_agent.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Timeline agent — project milestone management (list, create, update, delete).""" - -from __future__ import annotations - -import re -from typing import Any - -from langchain_core.tools import tool - -from app.core.ws_context import execute_on_client - -_UUID_RE = re.compile( - r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" -) - - -def _is_uuid(value: str) -> bool: - return bool(_UUID_RE.match(value)) - -TIMELINE_SYSTEM_PROMPT = ( - "You are a project timeline assistant. Timelines are milestone dates that\n" - "track progress on a project — they are not calendar events.\n\n" - "Rules:\n" - " - project_id is REQUIRED for every create; confirm with the user if unknown\n" - " - For listing, project_id must be a UUID; never pass plain names as project_id\n" - " - date is a Unix timestamp in milliseconds; convert human-readable dates\n" - " - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n" - " - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n" - " - For update_timeline, use -1 for integer fields you do not want to change\n" - " - Listing without a project_id returns all timelines across projects\n" - " - Always echo the title and formatted date in your confirmation." -) - - -@tool -async def list_timelines(project_id: str = "") -> str: - """List timelines. Provide project_id to scope to a specific project.""" - normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" - result = await execute_on_client( - action="select", - table="timelines", - filters={"projectId": normalized_project_id or None}, - ) - rows = result.get("rows", []) - if not rows: - return "No timelines found." - lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows] - return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines) - - -@tool -async def create_timeline( - project_id: str, - title: str, - date: int, - is_ai_suggested: int = 0, -) -> str: - """Create a project timeline (milestone). - project_id: REQUIRED UUID of the parent project - title: descriptive name for the milestone - date: Unix timestamp in milliseconds - is_ai_suggested: 1 if proactively suggested, 0 if user-requested - """ - result = await execute_on_client( - action="insert", - table="timelines", - data={ - "projectId": project_id, - "title": title, - "date": date, - "isAiSuggested": is_ai_suggested, - }, - ) - row = result["row"] - return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})" - - -@tool -async def update_timeline( - timeline_id: str, - title: str = "", - date: int = -1, -) -> str: - """Update a timeline. Only pass fields that should change. - timeline_id: UUID of the timeline (required) - date: -1 means unchanged; any other value sets the new date (ms timestamp) - """ - updates: dict[str, Any] = {} - if title: - updates["title"] = title - if date != -1: - updates["date"] = date - result = await execute_on_client( - action="update", - table="timelines", - data={"id": timeline_id, "updates": updates}, - ) - row = result["row"] - return f"Timeline updated: '{row['title']}' (id: {row['id']})" - - -@tool -async def delete_timeline(timeline_id: str) -> str: - """Delete a timeline permanently by its UUID.""" - await execute_on_client(action="delete", table="timelines", data={"id": timeline_id}) - return f"Timeline {timeline_id} deleted." - - -TIMELINE_TOOLS: list[Any] = [ - list_timelines, - create_timeline, - update_timeline, - delete_timeline, -] diff --git a/app/api/__init__.py b/app/api/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/api/deps.py b/app/api/deps.py deleted file mode 100644 index 0339d0d..0000000 --- a/app/api/deps.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Shared FastAPI dependencies. - -``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth`` -(the canonical location per Step 9). This module re-exports them so that all -existing route imports (``from app.api.deps import get_current_user``) continue -to work without modification. - -Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL -instead of reading it from the JWT payload. -""" - -from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401 - -__all__ = ["get_current_user", "oauth2_scheme"] diff --git a/app/api/middleware/__init__.py b/app/api/middleware/__init__.py deleted file mode 100644 index f67fc41..0000000 --- a/app/api/middleware/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""API middleware package. - -Exports the three middleware components introduced in Step 9: - - Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme`` - - Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter) - - Sanitizer: ``SanitizerMiddleware`` -""" - -from app.api.middleware.auth import get_current_user, oauth2_scheme -from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter -from app.api.middleware.sanitizer import SanitizerMiddleware - -__all__ = [ - "get_current_user", - "oauth2_scheme", - "TierRateLimitMiddleware", - "limiter", - "SanitizerMiddleware", -] diff --git a/app/api/middleware/auth.py b/app/api/middleware/auth.py deleted file mode 100644 index 4fcedf5..0000000 --- a/app/api/middleware/auth.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Auth middleware — JWT validation dependency. - -``get_current_user`` is the FastAPI dependency used by all protected routes. -It decodes the Bearer JWT (identity + expiry), then fetches the current tier -from the ``subscriptions`` table so that tier changes take effect immediately -without requiring token re-issue. - -Exempt routes (no JWT required): - - POST /api/v1/auth/register - - POST /api/v1/auth/login - - POST /api/v1/billing/webhook -""" - -from __future__ import annotations - -from fastapi import Depends, HTTPException, status -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config.settings import settings -from app.db import get_session -from app.schemas import UserProfile - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") - - -async def get_current_user( - token: str = Depends(oauth2_scheme), - db: AsyncSession = Depends(get_session), -) -> UserProfile: - """Validate a Bearer JWT and return the authenticated user. - - The JWT is used for identity and expiry only. The tier is fetched live - from the ``subscriptions`` table so that upgrades/downgrades take effect - immediately. Falls back to ``'free'`` when no subscription row exists. - - Raises HTTP 401 on any invalid or expired token. - """ - credentials_exc = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode( - token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] - ) - user_id: str | None = payload.get("sub") - email: str | None = payload.get("email") - if not user_id or not email: - raise credentials_exc - except JWTError: - raise credentials_exc - - # Live tier lookup — subscription row is the authoritative source. - # In dev, fall back to 'power' (unlimited) so quota limits don't - # block local development when no Stripe subscription exists. - from app.models import Subscription, User # noqa: PLC0415 - - result = await db.execute( - select(Subscription.tier).where(Subscription.user_id == user_id) - ) - default_tier = "power" if settings.ENV == "dev" else "free" - tier: str = result.scalar_one_or_none() or default_tier - - # Fetch name/surname from user row. - user_result = await db.execute( - select(User.name, User.surname).where(User.id == user_id) - ) - user_row = user_result.one_or_none() - - return UserProfile( - id=user_id, - email=email, - name=user_row.name if user_row else None, - surname=user_row.surname if user_row else None, - tier=tier, - ) # type: ignore[arg-type] diff --git a/app/api/middleware/rate_limit.py b/app/api/middleware/rate_limit.py deleted file mode 100644 index 4a2af76..0000000 --- a/app/api/middleware/rate_limit.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Tier-aware rate limiting middleware. - -Uses a per-user sliding-window counter (in-process, no Redis required). -The ``slowapi`` Limiter is also exported for optional route-level decoration. - -Limits (requests per minute): - - free: 20 - - pro: 60 - - power: 120 - - team: 200 - -Exempt paths bypass the limiter entirely: - - POST /api/v1/auth/register - - POST /api/v1/auth/login - - POST /api/v1/billing/webhook - - GET /api/v1/health -""" - -from __future__ import annotations - -import json -import time -from collections import defaultdict - -from fastapi import Request, Response -from jose import JWTError, jwt -from slowapi import Limiter -from slowapi.util import get_remote_address -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.types import ASGIApp - -from app.config.settings import settings - -_TIER_LIMITS: dict[str, int] = { - "free": 20, - "pro": 60, - "power": 120, - "team": 200, -} - -_EXEMPT_PATHS: frozenset[str] = frozenset( - { - "/api/v1/auth/register", - "/api/v1/auth/login", - "/api/v1/billing/webhook", - "/api/v1/health", - } -) - - -def _get_user_id_from_jwt(request: Request) -> str: - """Key function for the slowapi Limiter: returns JWT sub or remote IP.""" - auth = request.headers.get("Authorization", "") - token = auth.removeprefix("Bearer ").strip() - if not token: - return get_remote_address(request) - try: - payload = jwt.decode( - token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] - ) - return payload.get("sub") or get_remote_address(request) - except JWTError: - return get_remote_address(request) - - -# Exported Limiter instance — available for optional route-level decoration. -limiter = Limiter(key_func=_get_user_id_from_jwt) - - -class TierRateLimitMiddleware(BaseHTTPMiddleware): - """Sliding-window rate limiter applied globally across all non-exempt routes. - - Each authenticated user gets their own 60-second window sized by tier. - Unauthenticated requests pass through (the auth dependency will reject them - with 401 before the route handler runs). - """ - - def __init__(self, app: ASGIApp) -> None: - super().__init__(app) - # user_id → list of request timestamps (float, seconds since epoch) - self._window: dict[str, list[float]] = defaultdict(list) - - async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] - if request.url.path in _EXEMPT_PATHS: - return await call_next(request) - - # Extract JWT claims — if no valid token, pass through for auth dep to handle. - auth = request.headers.get("Authorization", "") - token = auth.removeprefix("Bearer ").strip() - if not token: - return await call_next(request) - - try: - payload = jwt.decode( - token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] - ) - user_id: str = payload.get("sub") or get_remote_address(request) - tier: str = payload.get("tier", "free") - except JWTError: - return await call_next(request) - - limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"]) - now = time.monotonic() - window_start = now - 60.0 - - # Slide the window: discard timestamps older than 60 seconds. - timestamps = [t for t in self._window[user_id] if t > window_start] - - if len(timestamps) >= limit: - retry_after = max(1, int(60 - (now - min(timestamps)))) - return Response( - content=json.dumps( - { - "detail": ( - f"Rate limit exceeded ({limit} req/min for {tier} tier). " - f"Retry in {retry_after}s." - ) - } - ), - status_code=429, - headers={ - "Retry-After": str(retry_after), - "Content-Type": "application/json", - }, - ) - - timestamps.append(now) - self._window[user_id] = timestamps - return await call_next(request) diff --git a/app/api/middleware/sanitizer.py b/app/api/middleware/sanitizer.py deleted file mode 100644 index 570937f..0000000 --- a/app/api/middleware/sanitizer.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Response sanitizer middleware. - -Scans JSON responses from the /api/v1/chat endpoint and strips any fragments -that could reveal server-side prompt IP: - - System prompt openers ("You are a/an/the …") - - Agent routing metadata ("Available agents:", "intent classifier", …) - - LangChain tool schema fragments (``"type": "function"``) - - Internal reasoning markers (, , [INST], …) - - Exact-match known prompt fingerprints - -Binary responses (storage blobs, backup data) are never touched — the -middleware only activates for paths under /api/v1/chat. - -Any sanitisation event is logged as a WARNING with the request path and the -names of the fields that were modified. -""" - -from __future__ import annotations - -import json -import logging -import re - -from fastapi import Request, Response -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.types import ASGIApp - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Detection patterns — order matters: fingerprints checked first (exact), -# then compiled regexes. -# --------------------------------------------------------------------------- - -_FINGERPRINTS: tuple[str, ...] = ( - "You are an intent classifier", - "Respond with just the agent name", - "Summarize these agent results", - "Available agents:", - "route to:", -) - -_PATTERNS: tuple[re.Pattern[str], ...] = ( - re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL), - re.compile(r"Available agents\s*:", re.IGNORECASE), - re.compile(r"\bintent classifier\b", re.IGNORECASE), - re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema - re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE), - re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers - re.compile(r"route\s+to\s*:", re.IGNORECASE), - re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE), -) - - -def _sanitize_text(text: str) -> tuple[str, bool]: - """Scan *text* for prompt fragments and replace matches with ``[REDACTED]``. - - Returns ``(cleaned_text, was_changed)``. - """ - # Fingerprint check — if any exact phrase is present, redact the whole string. - for fp in _FINGERPRINTS: - if fp in text: - return "[REDACTED]", True - - changed = False - for pattern in _PATTERNS: - new_text, n = pattern.subn("[REDACTED]", text) - if n: - text = new_text - changed = True - - return text, changed - - -class SanitizerMiddleware(BaseHTTPMiddleware): - """Strip prompt IP from /api/v1/chat JSON responses.""" - - def __init__(self, app: ASGIApp) -> None: - super().__init__(app) - - async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] - response: Response = await call_next(request) - - # Only process chat endpoint responses. - if not request.url.path.startswith("/api/v1/chat"): - return response - - # Read body — collect streaming chunks. - body_bytes = b"" - async for chunk in response.body_iterator: - body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode() - - # Skip non-JSON bodies (shouldn't happen on /chat, but be safe). - try: - body = json.loads(body_bytes.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError): - return Response( - content=body_bytes, - status_code=response.status_code, - headers=dict(response.headers), - media_type=response.media_type, - ) - - if not isinstance(body, dict): - return Response( - content=body_bytes, - status_code=response.status_code, - headers=dict(response.headers), - media_type=response.media_type, - ) - - # Walk top-level string fields and sanitise. - sanitised_fields: list[str] = [] - for key, value in body.items(): - if isinstance(value, str): - cleaned, changed = _sanitize_text(value) - if changed: - body[key] = cleaned - sanitised_fields.append(key) - - if sanitised_fields: - logger.warning( - "Sanitizer redacted prompt fragments", - extra={ - "path": request.url.path, - "fields": sanitised_fields, - }, - ) - - new_body = json.dumps(body).encode("utf-8") - headers = dict(response.headers) - headers["content-length"] = str(len(new_body)) - - return Response( - content=new_body, - status_code=response.status_code, - headers=headers, - media_type="application/json", - ) diff --git a/app/api/routes/__init__.py b/app/api/routes/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/api/routes/agent_setup.py b/app/api/routes/agent_setup.py deleted file mode 100644 index 2052d0b..0000000 --- a/app/api/routes/agent_setup.py +++ /dev/null @@ -1,406 +0,0 @@ -"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template. - -The journey is driven entirely through WebSocket frames (no REST endpoints). -The device WS handler dispatches ``journey_start`` and ``journey_message`` -frames to the functions exported here. - -Journey flow: - 1. FE sends ``journey_start`` frame with basic agent config (directory, - data_types, schedule). - 2. Server creates an in-memory session, sets up a WS executor so the - setup LLM can use file-system tools, does a first directory scrape, - and sends back a ``journey_reply`` with the first question. - 3. FE sends ``journey_message`` frames for each user reply. - 4. Server appends the user message, calls the LLM (which may read files - via tools), and sends back a ``journey_reply``. - 5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` - block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``. - 6. Server parses the block, sends ``journey_reply`` with ``done=True`` - and the template. FE stores it locally. -""" - -from __future__ import annotations - -import json -import logging -import time -import uuid -from dataclasses import dataclass, field -from typing import Any - -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage - -from app.agents.filesystem_agent import FILESYSTEM_TOOLS -from app.core.llm import get_llm - -logger = logging.getLogger(__name__) - -# ── Session TTL ─────────────────────────────────────────────────────────── - -_SESSION_TTL_SECONDS: int = 1800 # 30 minutes - -# Sentinel strings used to delimit the LLM-produced prompt_template. -_TEMPLATE_START = "PROMPT_TEMPLATE_START" -_TEMPLATE_END = "PROMPT_TEMPLATE_END" - -# Minimum turns before we consider nudging the LLM to wrap up. -_MIN_TURNS_BEFORE_NUDGE: int = 3 -# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion). -_MAX_TURNS: int = 15 -# Max tool-calling steps per LLM invocation. -_MAX_TOOL_STEPS: int = 6 - -# ── In-memory session store ─────────────────────────────────────────────── - - -@dataclass -class JourneySession: - session_id: str - user_id: str - agent_type: str # "local" | "cloud" - directory: str - data_types: list[str] - history: list[dict[str, Any]] = field(default_factory=list) - system_prompt: str = "" - created_at: float = field(default_factory=time.monotonic) - - def is_expired(self) -> bool: - return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS - - -# session_id → session -_sessions: dict[str, JourneySession] = {} - - -def get_journey_session(session_id: str, user_id: str) -> JourneySession | None: - """Retrieve session; return None on missing, expired, or wrong owner.""" - s = _sessions.get(session_id) - if s is None or s.is_expired(): - _sessions.pop(session_id, None) - return None - if s.user_id != user_id: - return None - return s - - -# ── System prompt builder ───────────────────────────────────────────────── - -_SYSTEM_PROMPT_TEMPLATE = """\ -You are a friendly assistant helping a freelancer configure a data-extraction agent. -Your job is to understand exactly what data the user wants to extract from their -local directory and produce a detailed prompt_template that a separate AI will use -as its instruction set. - -The extraction agent already has this base behaviour built in: - - Reads each file using file-system tools. - - Creates records (tasks, notes, timelines, projects) via CRUD tools. - - Sets isAiSuggested=1 on every new record. - - Only extracts data explicitly present in the files — it never invents information. -The user's custom prompt is appended AFTER this base behaviour, so focus on -what to look for and how to map it — not on the general extraction mechanics. - -You have access to file-system tools to explore the user's directory: -- list_directory: to see folder structure -- read_file_content: to peek at file contents -- get_file_metadata: to check file info - -The user's configured directory is: {directory} -Target data types: {data_types} - -IMPORTANT — project assignment is handled automatically by the main agent runner -before the custom prompt is ever used. You MUST NOT ask the user about projects, -projectId, or how to link records to projects. Never include projectId logic or -project creation instructions in the generated prompt_template. - -Start by exploring the directory to understand its structure. Then ask concise, -focused questions one at a time. Cover these topics (not necessarily in this order): - 1. The type and format of the source content (confirmed by your exploration). - 2. How fields should be mapped (e.g. filename → task title). - 3. Priority or status rules (e.g. "urgent" keyword → high priority). - 4. Any special handling, date extraction, or exclusions. - -Once you reach 90% confidence, output the final prompt_template between these exact -markers on their own lines: - -{template_start} - -{template_end} - -The prompt_template must be a self-contained instruction for an AI that reads files -and must perform CRUD operations using tools to create records. It should specify: - - What entity types to create (tasks, notes, timelines) — never projects. - - How to map file content to record fields (camelCase: title, status, priority, - dueDate, content, etc.) — never include projectId. - - That isAiSuggested must be set to 1 on every new record. - - Concrete examples of mappings based on what you discovered in the directory. - -{existing_section}\ -Keep asking clarifying questions until you are at least 90% confident you have -enough information to generate an accurate prompt_template. Once you reach that -confidence level, stop asking and produce the final template immediately. -Begin by exploring the directory, then ask your first question.\ -""" - - -def _build_system_prompt( - directory: str, - data_types: list[str], - existing_template: str | None = None, -) -> str: - existing_section = ( - f"\nThe user already has the following prompt_template — refine it based on their answers:\n" - f"---\n{existing_template}\n---\n" - if existing_template - else "" - ) - return _SYSTEM_PROMPT_TEMPLATE.format( - directory=directory, - data_types=", ".join(data_types), - template_start=_TEMPLATE_START, - template_end=_TEMPLATE_END, - existing_section=existing_section, - ) - - -# ── Template extraction ─────────────────────────────────────────────────── - - -def _extract_template(text: str) -> str | None: - """Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None.""" - if _TEMPLATE_START not in text or _TEMPLATE_END not in text: - return None - start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START) - end_idx = text.index(_TEMPLATE_END) - return text[start_idx:end_idx].strip() or None - - -# ── LLM call with tool support ─────────────────────────────────────────── - - -def _as_text(content: Any) -> str: - if content is None: - return "" - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, str): - parts.append(item) - elif isinstance(item, dict): - text = item.get("text") - if isinstance(text, str): - parts.append(text) - return "".join(parts) - return str(content) - - -async def _call_llm_with_tools( - system_prompt: str, - history: list[dict[str, Any]], - tools: list[Any], -) -> str: - """Build LangChain messages from history and invoke the LLM with tools. - - Handles tool-calling loops: if the LLM calls tools, execute them and - continue until a final text response is produced. - """ - messages: list[Any] = [SystemMessage(content=system_prompt)] - for turn in history: - if turn["role"] == "user": - messages.append(HumanMessage(content=turn["content"])) - else: - messages.append(AIMessage(content=turn["content"])) - - llm = get_llm(model=None, temperature=0.4) - llm_with_tools = llm.bind_tools(tools) - tool_map = {tool_def.name: tool_def for tool_def in tools} - - for _ in range(_MAX_TOOL_STEPS): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) - - if not response.tool_calls: - return _as_text(response.content) - - for call in response.tool_calls: - call_name = str(call.get("name", "")) - call_args = call.get("args", {}) - logger.info( - "agent_setup: journey tool_call name=%s args=%s", - call_name, - json.dumps(call_args, ensure_ascii=True)[:500], - ) - - tool_fn = tool_map.get(call_name) - if tool_fn is None: - tool_output = f"Unknown tool: {call_name}" - else: - tool_output = await tool_fn.ainvoke(call_args) - - logger.info( - "agent_setup: journey tool_result name=%s output=%s", - call_name, - str(tool_output)[:800], - ) - messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - - # Fallback: exceeded max steps. - final = await llm.ainvoke(messages) - return _as_text(final.content) - - -# ── Journey handlers (called from device_ws.py) ────────────────────────── - - -async def handle_journey_start( - user_id: str, - frame: dict[str, Any], -) -> dict[str, Any]: - """Handle a ``journey_start`` WS frame. - - Creates a session, runs the setup LLM with directory exploration, - and returns the ``journey_reply`` payload. - """ - agent_type = frame.get("agent_type", "local") - directory = frame.get("directory", "") - data_types = frame.get("data_types", []) - existing_template = frame.get("existing_template") - - # Use the session_id provided by the FE so the reply matches the - # listener key; fall back to a generated one if absent. - session_id = frame.get("session_id") or str(uuid.uuid4()) - system_prompt = _build_system_prompt(directory, data_types, existing_template) - - session = JourneySession( - session_id=session_id, - user_id=user_id, - agent_type=agent_type, - directory=directory, - data_types=data_types, - system_prompt=system_prompt, - ) - - # The LLM will explore the directory using FILESYSTEM_TOOLS via the - # ws_context executor (already set by the WS handler before calling us). - # Seed with an initial user message — some providers (e.g. GitHub Copilot) - # require at least one user/input message to be present. - seed_history: list[dict[str, Any]] = [ - {"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."}, - ] - ai_reply = await _call_llm_with_tools( - system_prompt=system_prompt, - history=seed_history, - tools=list(FILESYSTEM_TOOLS), - ) - - session.history.extend(seed_history) - session.history.append({"role": "assistant", "content": ai_reply}) - _sessions[session_id] = session - - logger.info( - "agent_setup: journey session %s started for user %s (directory=%s)", - session_id, - user_id, - directory, - ) - - # Check if the LLM produced the template on the first turn (unlikely but possible). - prompt_template = _extract_template(ai_reply) - done = prompt_template is not None - - display_message = ai_reply - if done: - display_message = ( - ai_reply[: ai_reply.index(_TEMPLATE_START)].strip() - or "Here is your agent configuration. You can save it or continue refining." - ) - _sessions.pop(session_id, None) - - return { - "type": "journey_reply", - "session_id": session_id, - "message": display_message, - "done": done, - "prompt_template": prompt_template, - } - - -async def handle_journey_message( - user_id: str, - frame: dict[str, Any], -) -> dict[str, Any]: - """Handle a ``journey_message`` WS frame. - - Appends the user message, calls the LLM, and returns the - ``journey_reply`` payload. - """ - session_id = frame.get("session_id", "") - message = frame.get("message", "") - - session = get_journey_session(session_id, user_id) - if session is None: - return { - "type": "journey_reply", - "session_id": session_id, - "message": "Journey session not found or expired. Please start a new setup.", - "done": True, - "prompt_template": None, - } - - # Append user turn. - session.history.append({"role": "user", "content": message}) - - # Call the LLM with tools. - ai_reply = await _call_llm_with_tools( - system_prompt=session.system_prompt, - history=session.history, - tools=list(FILESYSTEM_TOOLS), - ) - - session.history.append({"role": "assistant", "content": ai_reply}) - - # Check if the LLM produced the final template. - prompt_template = _extract_template(ai_reply) - done = prompt_template is not None - - # If the LLM didn't produce a template, nudge it once it has asked enough - # questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap. - if not done: - turns = sum(1 for t in session.history if t["role"] == "user") - if turns >= _MAX_TURNS: - nudge_content = ( - "[System: You have enough information. Please generate the final " - f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]" - ) - session.history.append({"role": "user", "content": nudge_content}) - - nudge_reply = await _call_llm_with_tools( - system_prompt=session.system_prompt, - history=session.history, - tools=list(FILESYSTEM_TOOLS), - ) - session.history.append({"role": "assistant", "content": nudge_reply}) - - prompt_template = _extract_template(nudge_reply) - if prompt_template is not None: - done = True - ai_reply = nudge_reply - - display_message = ai_reply - if done: - display_message = ( - ai_reply[: ai_reply.index(_TEMPLATE_START)].strip() - if _TEMPLATE_START in ai_reply - else "Here is your agent configuration. You can save it or continue refining." - ) - _sessions.pop(session_id, None) - logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id) - - return { - "type": "journey_reply", - "session_id": session_id, - "message": display_message, - "done": done, - "prompt_template": prompt_template, - } diff --git a/app/api/routes/agents.py b/app/api/routes/agents.py deleted file mode 100644 index 30ecfc9..0000000 --- a/app/api/routes/agents.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Agent routes. - -Backend responsibilities are intentionally minimal: - GET /agents/catalog — static catalog for UI display - POST /agents/can-create — billing eligibility check - POST /agents/trigger — trigger a local agent run - -Agent configuration is owned by the Electron app and is not persisted -in backend agent-config tables. -""" - -from __future__ import annotations - -import asyncio -import uuid -from datetime import datetime, timedelta, timezone - -from fastapi import APIRouter, Depends, HTTPException, status -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.api.deps import get_current_user -from app.billing.tier_manager import FEATURES -from app.core.agent_runner import is_agent_running, run_local_agent -from app.core.device_manager import device_manager -from app.db import get_session -from app.models import AgentRunLog, LocalAgentConfig -from app.schemas import ( - AgentCatalogItem, - AgentCreationCheckRequest, - AgentCreationCheckResponse, - AgentRunLogResponse, - AgentTriggerRequest, - UserProfile, -) - -router = APIRouter(prefix="/agents", tags=["agents"]) - - -# ── Datetime helpers ────────────────────────────────────────────────── - -def _dt_ms(dt: datetime) -> int: - return int(dt.timestamp() * 1000) - - -def _dt_ms_opt(dt: datetime | None) -> int | None: - return int(dt.timestamp() * 1000) if dt else None - - -def _to_data_types(values: list[str]) -> list[str]: - normalize = { - "task": "tasks", "tasks": "tasks", - "note": "notes", "notes": "notes", - "timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines", - "project": "projects", "projects": "projects", - } - seen: set[str] = set() - result: list[str] = [] - for v in values: - mapped = normalize.get(v) - if mapped and mapped not in seen: - seen.add(mapped) - result.append(mapped) - return result - - -def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse: - return AgentRunLogResponse( - id=log.id, - agent_id=log.agent_id, - agent_type=log.agent_type, # type: ignore[arg-type] - status=log.status, # type: ignore[arg-type] - items_processed=log.items_processed, - items_created=log.items_created, - errors=log.errors or [], - started_at=_dt_ms(log.started_at), - completed_at=_dt_ms_opt(log.completed_at), - ) - - -def _enforce_agent_limit(tier: str, current_count: int) -> int: - limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] - if limit != -1 and current_count >= limit: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", - ) - return limit - - -async def _enforce_run_frequency( - tier: str, - user_id: str, - db: AsyncSession, -) -> None: - """Raise HTTP 402 if the user has exceeded their daily batch run limit.""" - limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"] - if limit == -1: - return # unlimited - - today_start = datetime.now(timezone.utc).replace( - hour=0, minute=0, second=0, microsecond=0 - ) - result = await db.execute( - select(func.count(AgentRunLog.id)).where( - AgentRunLog.user_id == user_id, - AgentRunLog.started_at >= today_start, - ) - ) - runs_today: int = result.scalar_one() - - if runs_today >= limit: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.", - ) - - -# ── Catalog ─────────────────────────────────────────────────────────── - -@router.get("/catalog", response_model=list[AgentCatalogItem]) -async def get_agent_catalog( - current_user: UserProfile = Depends(get_current_user), -) -> list[AgentCatalogItem]: - """Return the static list of available agent types and their descriptions.""" - return [ - AgentCatalogItem( - type="local_directory", - name="Local Directory Monitor", - description="Watches local directories, extracts data from files using AI", - ), - AgentCatalogItem( - type="gmail", - name="Gmail Connector", - description="Scans Gmail inbox, extracts tasks/notes from emails", - ), - AgentCatalogItem( - type="teams", - name="Microsoft Teams Connector", - description="Monitors Teams messages, extracts action items", - ), - AgentCatalogItem( - type="outlook", - name="Outlook Connector", - description="Scans Outlook inbox, extracts tasks/notes", - ), - ] - - -@router.post("/can-create", response_model=AgentCreationCheckResponse) -async def can_create_agent( - body: AgentCreationCheckRequest, - current_user: UserProfile = Depends(get_current_user), -) -> AgentCreationCheckResponse: - """Check if the user can create one more agent based on billing tier. - - Since configuration is client-owned, the Electron app sends its current - active agent count and the backend applies tier limits. - """ - limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"] - allowed = limit == -1 or body.active_agents < limit - return AgentCreationCheckResponse( - allowed=allowed, - tier=current_user.tier, - active_agents=body.active_agents, - limit=limit, - ) - - -@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED) -async def trigger_agent_run( - body: AgentTriggerRequest, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> AgentRunLogResponse: - """Trigger a local agent run using client-provided configuration.""" - _enforce_agent_limit(current_user.tier, body.active_agents) - await _enforce_run_frequency(current_user.tier, current_user.id, db) - - config = LocalAgentConfig( - id=str(uuid.uuid4()), - user_id=current_user.id, - device_id=body.device_id, - name="Local Directory Monitor", - directory_paths=[body.directory], - data_types=_to_data_types(body.what_to_extract), - prompt_template=body.custom_agent_prompt, - file_extensions=[], - schedule_cron=body.batch_interval, - enabled=True, - ) - - # Use the FE's stable agent_id if provided, fall back to the ephemeral config id. - stable_agent_id = body.agent_id or config.id - - if is_agent_running(stable_agent_id): - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="Agent is already running. Only one run per agent is allowed at a time.", - ) - - run_log = AgentRunLog( - agent_id=stable_agent_id, - agent_type="local", - user_id=current_user.id, - status="running", - ) - db.add(run_log) - await db.commit() - await db.refresh(run_log) - - run_context = { - "type": "agent_batch", - "run_id": run_log.id, - "agent_id": stable_agent_id, - } - - asyncio.create_task( - run_local_agent(current_user.id, config, run_log, device_manager, run_context) - ) - - return _to_run_log_response(run_log) diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py deleted file mode 100644 index 1ab10ea..0000000 --- a/app/api/routes/auth.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Auth routes: register, login, refresh, me. - -Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens -tables). Passwords are hashed with bcrypt; refresh tokens are stored as -SHA-256 hashes so plaintext never reaches the DB. -""" - -from __future__ import annotations - -import hashlib -import time -import uuid -from datetime import datetime, timedelta, timezone - -import bcrypt -from cryptography.fernet import Fernet -from fastapi import APIRouter, Depends, HTTPException, status -from jose import jwt -from pydantic import BaseModel -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.api.deps import get_current_user -from app.config.settings import settings -from app.db import get_session -from app.models import RefreshToken, User -from app.schemas import AuthTokens, UserProfile - -router = APIRouter(prefix="/auth", tags=["auth"]) - - -# ── Internal helpers ───────────────────────────────────────────────── - - -def _hash_password(password: str) -> str: - return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() - - -def _verify_password(password: str, hashed: str) -> bool: - return bcrypt.checkpw(password.encode(), hashed.encode()) - - -def _hash_token(plain_token: str) -> str: - """SHA-256 of the plain refresh token string.""" - return hashlib.sha256(plain_token.encode()).hexdigest() - - -def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]: - """Return (signed JWT, expires_at_ms).""" - now = int(time.time()) - exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 - payload = { - "sub": user_id, - "email": email, - "tier": tier, - "exp": exp, - "iat": now, - } - token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) - return token, exp * 1000 # ms for client - - -# ── Request bodies ──────────────────────────────────────────────────── - - -class _RegisterRequest(BaseModel): - email: str - password: str - name: str | None = None - surname: str | None = None - - -class _LoginRequest(BaseModel): - email: str - password: str - - -class _RefreshRequest(BaseModel): - refresh_token: str - - -# ── Routes ──────────────────────────────────────────────────────────── - - -@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED) -async def register( - body: _RegisterRequest, - db: AsyncSession = Depends(get_session), -) -> AuthTokens: - """Create a new account and return JWT tokens.""" - existing = await db.execute(select(User).where(User.email == body.email)) - if existing.scalar_one_or_none() is not None: - raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") - - user = User( - id=str(uuid.uuid4()), - email=body.email, - name=body.name, - surname=body.surname, - password_hash=_hash_password(body.password), - tier="free", - encryption_key=Fernet.generate_key().decode(), - ) - db.add(user) - await db.flush() # get user.id without committing - - plain_token = str(uuid.uuid4()) - expires_at = datetime.now(timezone.utc) + timedelta( - days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS - ) - rt = RefreshToken( - user_id=user.id, - token_hash=_hash_token(plain_token), - expires_at=expires_at, - ) - db.add(rt) - await db.commit() - - access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) - return AuthTokens( - access_token=access_token, - refresh_token=plain_token, - expires_at=expires_at_ms, - ) - - -@router.post("/login", response_model=AuthTokens) -async def login( - body: _LoginRequest, - db: AsyncSession = Depends(get_session), -) -> AuthTokens: - """Validate credentials and return JWT tokens.""" - result = await db.execute(select(User).where(User.email == body.email)) - user = result.scalar_one_or_none() - if user is None or not _verify_password(body.password, user.password_hash): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials") - - plain_token = str(uuid.uuid4()) - expires_at = datetime.now(timezone.utc) + timedelta( - days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS - ) - rt = RefreshToken( - user_id=user.id, - token_hash=_hash_token(plain_token), - expires_at=expires_at, - ) - db.add(rt) - await db.commit() - - access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) - return AuthTokens( - access_token=access_token, - refresh_token=plain_token, - expires_at=expires_at_ms, - ) - - -@router.post("/refresh", response_model=AuthTokens) -async def refresh( - body: _RefreshRequest, - db: AsyncSession = Depends(get_session), -) -> AuthTokens: - """Rotate a refresh token and return a new token pair.""" - token_hash = _hash_token(body.refresh_token) - result = await db.execute( - select(RefreshToken).where(RefreshToken.token_hash == token_hash) - ) - rt = result.scalar_one_or_none() - - now = datetime.now(timezone.utc) - if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token") - - # Rotate: delete old token, issue new one. - await db.delete(rt) - - user_result = await db.execute(select(User).where(User.id == rt.user_id)) - user = user_result.scalar_one_or_none() - if user is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found") - - plain_token = str(uuid.uuid4()) - new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) - new_rt = RefreshToken( - user_id=user.id, - token_hash=_hash_token(plain_token), - expires_at=new_expires, - ) - db.add(new_rt) - await db.commit() - - access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) - return AuthTokens( - access_token=access_token, - refresh_token=plain_token, - expires_at=expires_at_ms, - ) - - -class _UpdateProfileRequest(BaseModel): - name: str | None = None - surname: str | None = None - - -@router.get("/me", response_model=UserProfile) -async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile: - """Return the profile for the authenticated user.""" - return current_user - - -@router.put("/me", response_model=UserProfile) -async def update_profile( - body: _UpdateProfileRequest, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> UserProfile: - """Update the authenticated user's name and surname.""" - result = await db.execute(select(User).where(User.id == current_user.id)) - user = result.scalar_one() - - if body.name is not None: - user.name = body.name - if body.surname is not None: - user.surname = body.surname - - await db.commit() - await db.refresh(user) - - return UserProfile( - id=user.id, - email=user.email, - name=user.name, - surname=user.surname, - tier=current_user.tier, - ) diff --git a/app/api/routes/backup.py b/app/api/routes/backup.py deleted file mode 100644 index 2b8eeae..0000000 --- a/app/api/routes/backup.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Backup routes: upload, download, history, and delete E2E-encrypted backups. - -Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the -PostgreSQL ``backup_metadata`` table. - -IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI -treating "history" as a ``{backup_id}`` path parameter. -""" - -from __future__ import annotations - -import uuid -from email.utils import parsedate_to_datetime - -from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.api.deps import get_current_user -from app.billing.tier_manager import tier_manager -from app.db import get_session -from app.models import BackupMetadata as BackupMetadataModel -from app.schemas import BackupMetadata, UserProfile -from app.storage.blob_store import BlobStore -from app.storage.encryption import reject_if_tampered - -router = APIRouter(prefix="/backup", tags=["backup"]) - -_blob_store = BlobStore() - - -async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int: - """Return total backup bytes stored by *user_id*.""" - result = await db.execute( - select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where( - BackupMetadataModel.user_id == user_id - ) - ) - return int(result.scalar_one()) - - -async def _check_backup_quota( - user: UserProfile, size_bytes: int, db: AsyncSession -) -> None: - """Raise HTTP 402 if the upload would exceed the tier's backup limit.""" - current = await _current_backup_bytes(user.id, db) - tier_manager.enforce_backup_quota( - user.tier, current_bytes=current, additional_bytes=size_bytes - ) - - -@router.put("") -async def upload_backup( - request: Request, - x_backup_version: int = Header(..., alias="X-Backup-Version"), - x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"), - x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"), - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Upload an E2E-encrypted backup blob. - - Metadata is passed via custom headers; the raw body is the encrypted blob. - """ - blob = await request.body() - reject_if_tampered(blob, x_backup_checksum) - await _check_backup_quota(current_user, len(blob), db) - - s3_key = await _blob_store.upload( - current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum - ) - - row = BackupMetadataModel( - id=str(uuid.uuid4()), - user_id=current_user.id, - s3_key=s3_key, - version=x_backup_version, - timestamp=x_backup_timestamp, - checksum=x_backup_checksum, - size_bytes=len(blob), - ) - db.add(row) - await db.commit() - - return {"ok": True} - - -@router.get("/history", response_model=list[BackupMetadata]) -async def backup_history( - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> list[BackupMetadata]: - """Return backup metadata records for the authenticated user (no blob bytes).""" - result = await db.execute( - select(BackupMetadataModel) - .where(BackupMetadataModel.user_id == current_user.id) - .order_by(BackupMetadataModel.timestamp.desc()) - ) - rows = result.scalars().all() - return [ - BackupMetadata( - version=r.version, - timestamp=r.timestamp, - checksum=r.checksum, - chunk_count=1, - ) - for r in rows - ] - - -@router.get("") -async def download_backup( - request: Request, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> Response: - """Download the latest backup blob. Supports ``If-Modified-Since``.""" - result = await db.execute( - select(BackupMetadataModel) - .where(BackupMetadataModel.user_id == current_user.id) - .order_by(BackupMetadataModel.timestamp.desc()) - .limit(1) - ) - latest = result.scalar_one_or_none() - if latest is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found") - - ims_header = request.headers.get("If-Modified-Since") - if ims_header: - try: - ims_dt = parsedate_to_datetime(ims_header) - ims_ms = int(ims_dt.timestamp() * 1000) - if latest.timestamp <= ims_ms: - return Response(status_code=status.HTTP_304_NOT_MODIFIED) - except Exception: - pass # malformed header — ignore and serve the blob - - blob = await _blob_store.download(current_user.id, latest.s3_key) - return Response( - content=blob, - media_type="application/octet-stream", - headers={ - "X-Backup-Version": str(latest.version), - "X-Backup-Timestamp": str(latest.timestamp), - "X-Checksum": latest.checksum, - }, - ) - - -@router.delete("/{backup_id}", response_model=dict) -async def delete_backup( - backup_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Delete a specific backup by ID.""" - result = await db.execute( - select(BackupMetadataModel).where( - BackupMetadataModel.id == backup_id, - BackupMetadataModel.user_id == current_user.id, - ) - ) - target = result.scalar_one_or_none() - if target is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found") - - await _blob_store.delete(current_user.id, target.s3_key) - await db.delete(target) - await db.commit() - - return {"ok": True} diff --git a/app/api/routes/billing.py b/app/api/routes/billing.py deleted file mode 100644 index e8bdef2..0000000 --- a/app/api/routes/billing.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Billing routes: Stripe checkout, webhook, subscription management. - -Business logic lives in ``app.billing.stripe_service.StripeService``. -The route layer handles HTTP concerns (request parsing, response shaping) -and delegates everything else to the service singleton. -""" - -from __future__ import annotations - -from typing import Any - -from fastapi import APIRouter, Depends, Header, Request, status -from pydantic import BaseModel -from sqlalchemy.ext.asyncio import AsyncSession - -from app.api.deps import get_current_user -from app.billing.stripe_service import stripe_service -from app.db import get_session -from app.schemas import BillingTier, UserProfile - -router = APIRouter(prefix="/billing", tags=["billing"]) - - -# ── Request bodies ───────────────────────────────────────────────────── - -class _CheckoutRequest(BaseModel): - tier: BillingTier - - -# ── Routes ───────────────────────────────────────────────────────────── - -@router.post("/checkout", response_model=dict) -async def create_checkout( - body: _CheckoutRequest, - current_user: UserProfile = Depends(get_current_user), -) -> dict[str, str]: - """Create a Stripe checkout session for a tier upgrade. - - Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured. - """ - url = stripe_service.create_checkout_session(current_user.id, body.tier) - return {"checkout_url": url} - - -@router.post("/webhook", response_model=dict) -async def stripe_webhook( - request: Request, - stripe_signature: str = Header(default="", alias="Stripe-Signature"), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Handle Stripe webhook events. - - No JWT auth — authenticated via Stripe signature verification instead. - Returns 200 immediately when Stripe is not configured (local dev). - """ - payload = await request.body() - await stripe_service.handle_webhook(payload, stripe_signature, db) - return {"ok": True} - - -@router.get("/subscription", response_model=dict) -async def get_subscription( - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, Any]: - """Return the current subscription info for the authenticated user.""" - sub = await stripe_service.get_subscription(current_user.id, db) - if sub is None: - return { - "tier": current_user.tier, - "status": "free", - "stripe_subscription_id": None, - "current_period_end": None, - } - return sub - - -@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK) -async def cancel_subscription( - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Cancel the active subscription.""" - await stripe_service.cancel_subscription(current_user.id, db) - return {"ok": True} diff --git a/app/api/routes/chat.py b/app/api/routes/chat.py deleted file mode 100644 index 6270d0e..0000000 --- a/app/api/routes/chat.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Chat routes: POST /chat (REST fallback). - -WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device). -""" - -from __future__ import annotations - -from fastapi import APIRouter, Depends -from fastapi.responses import JSONResponse - -from app.api.deps import get_current_user -from app.core.deep_agent import run_home -from app.schemas import ChatRequest, UserProfile - -router = APIRouter(prefix="/chat", tags=["chat"]) - - -@router.post("") -async def chat( - body: ChatRequest, - current_user: UserProfile = Depends(get_current_user), -) -> JSONResponse: - """REST fallback for home chat when websocket streaming is unavailable.""" - response = await run_home( - user_id=current_user.id, - message=body.message, - context=body.context.model_dump(), - ) - return JSONResponse(content={"response": response}) diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py deleted file mode 100644 index e868c2d..0000000 --- a/app/api/routes/device_ws.py +++ /dev/null @@ -1,417 +0,0 @@ -"""Device WebSocket endpoint. - -Persistent connection from Electron devices to the backend. - - WS /api/v1/ws/device?token= - -Auth: JWT passed as ``?token=`` query parameter (Bearer header is not -available during the WebSocket handshake). - -Protocol: - 1. Client connects → JWT validated → connection accepted. - 2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``. - 3. Backend registers the connection in ``DeviceConnectionManager``. - 4. Session enters message dispatch loop + heartbeat. - -Incoming frame dispatch: - - ``tool_result`` → resolves a pending tool-call Future. - - ``journey_start`` → starts a guided setup journey session. - - ``journey_message`` → continues a journey conversation. - - ``pong`` → heartbeat acknowledgement (updates last-seen). - - unknown types → logged, ignored. - -Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s. - -On disconnect: - - Unregisters from DeviceConnectionManager. - - Marks all in-progress AgentRunLog rows for this user as ``error`` - with message "device disconnected". -""" - -from __future__ import annotations - -import asyncio -import json -import logging -from uuid import uuid4 - -from fastapi import APIRouter, WebSocket, WebSocketDisconnect -from jose import JWTError, jwt -from sqlalchemy import update - -from app.api.routes.agent_setup import handle_journey_message, handle_journey_start -from app.config.settings import settings -from app.core.agent_runner import trigger_pending_runs -from app.core.deep_agent import run_floating_stream, run_home_stream -from app.core.device_manager import device_manager -from app.core.memory_middleware import MemoryMiddleware -from app.core.output_formatter import StreamFormatter -from app.core.ws_context import clear_client_executor, set_client_executor -from app.db import async_session -from app.models import AgentRunLog -from app.schemas import WsFrameType - -logger = logging.getLogger(__name__) - -router = APIRouter(prefix="/ws", tags=["device-ws"]) - -_HEARTBEAT_INTERVAL = 30 # seconds -_PONG_TIMEOUT = 10 # seconds — grace window after a ping - - -@router.websocket("/device") -async def device_ws(websocket: WebSocket) -> None: - """Persistent WebSocket endpoint for Electron device connections. - - Authentication is via ``?token=`` query parameter. - """ - # ── 1. Authenticate before accepting ───────────────────────────── - token = websocket.query_params.get("token", "") - try: - payload = jwt.decode( - token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] - ) - user_id: str | None = payload.get("sub") - if not user_id: - raise JWTError("missing sub") - except JWTError: - await websocket.close(code=1008) # Policy Violation - return - - await websocket.accept() - - # ── 2. Await device_hello frame ─────────────────────────────────── - try: - raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0) - except (asyncio.TimeoutError, WebSocketDisconnect): - await websocket.close(code=1008) - return - - try: - hello = json.loads(raw) - if hello.get("type") != WsFrameType.device_hello: - raise ValueError("expected device_hello as first frame") - device_id: str = hello["device_id"] - agent_ids: list[str] = hello.get("agent_ids", []) - except (KeyError, ValueError, json.JSONDecodeError) as exc: - logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc) - await websocket.close(code=1008) - return - - # ── 3. Register connection ──────────────────────────────────────── - device_manager.register(user_id, device_id, websocket) - logger.info( - "device_ws: connected user=%s device=%s agents=%s", - user_id, - device_id, - agent_ids, - ) - - # Trigger any overdue agent runs now that the device is connected. - asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) - - # ── 4. Concurrent message loop + heartbeat ──────────────────────── - try: - await asyncio.gather( - _message_loop(websocket, user_id), - _heartbeat_loop(websocket), - ) - except WebSocketDisconnect: - pass - except Exception as exc: - logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc) - finally: - device_manager.unregister(user_id) - logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id) - await _mark_runs_disconnected(user_id) - - -# ── Message dispatch loop ───────────────────────────────────────────── - -async def _message_loop(websocket: WebSocket, user_id: str) -> None: - """Receive frames from Electron and dispatch to the appropriate handler.""" - async for raw in websocket.iter_text(): - try: - frame: dict = json.loads(raw) - except json.JSONDecodeError: - logger.warning("device_ws: invalid JSON from user=%s", user_id) - continue - - frame_type = frame.get("type") - - if frame_type == WsFrameType.tool_result: - call_id = frame.get("id") - if call_id: - device_manager.resolve_pending_call(user_id, call_id, frame) - else: - logger.warning( - "device_ws: tool_result missing id from user=%s", user_id - ) - - elif frame_type == WsFrameType.home_request: - asyncio.create_task( - _handle_home_request(websocket, user_id, frame) - ) - - elif frame_type == WsFrameType.floating_request: - asyncio.create_task( - _handle_floating_request(websocket, user_id, frame) - ) - - elif frame_type == WsFrameType.journey_start: - asyncio.create_task( - _handle_journey_start(websocket, user_id, frame) - ) - - elif frame_type == WsFrameType.journey_message: - asyncio.create_task( - _handle_journey_message(websocket, user_id, frame) - ) - - elif frame_type == "pong": - # Heartbeat ack — nothing to do, connection is alive. - pass - - else: - logger.debug( - "device_ws: unknown frame type %r from user=%s", frame_type, user_id - ) - - -# ── v3 Chat Handlers ────────────────────────────────────────────────── - -async def _make_ws_executor(websocket: WebSocket, user_id: str): - """Return a callback that sends tool_call frames and awaits tool_result.""" - async def _executor(payload: dict) -> dict: - payload["type"] = WsFrameType.tool_call - await websocket.send_text(json.dumps(payload)) - future = device_manager.create_pending_call(user_id, payload["id"]) - return await future - return _executor - - -async def _handle_home_request( - websocket: WebSocket, - user_id: str, - frame: dict, -) -> None: - """Handle a home_request frame — streams HomeFormatter output back on the socket.""" - request_id = frame.get("request_id") or str(uuid4()) - message: str = frame.get("message", "") - session_id: str = frame.get("session_id") or str(uuid4()) - logger.info( - "device_ws: home_request_start user=%s req=%s session=%s msg=%s", - user_id, - request_id, - session_id, - message[:200], - ) - - # ── Memory: enrich context before LLM call ──────────────────────── - async with async_session() as db: - memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context( - user_id, - message, - trace_id=request_id, - session_id=session_id, - ) - - context: dict = { - "conversation_history": frame.get("conversation_history", []), - "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, - **memory_context, - } - - executor = await _make_ws_executor(websocket, user_id) - set_client_executor(executor) - response_chunks: list[str] = [] - try: - event_stream = run_home_stream(user_id, message, context) - formatter = StreamFormatter(request_id=request_id) - async for ws_frame in formatter.format(event_stream): - await websocket.send_text(ws_frame.model_dump_json()) - # Collect text chunks to build the full response for episode storage - if ws_frame.type == "stream_text": # type: ignore[union-attr] - response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] - except Exception as exc: - logger.error( - "device_ws: home_request failed user=%s req=%s: %s", - user_id, request_id, exc, - ) - finally: - clear_client_executor() - - # ── Memory: store episode after response ────────────────────────── - async with async_session() as db: - memory = MemoryMiddleware(db) - await memory.store_episode( - user_id, session_id, message, "".join(response_chunks), trace_id=request_id - ) - logger.info( - "device_ws: home_request_end user=%s req=%s session=%s response_chars=%d", - user_id, - request_id, - session_id, - len("".join(response_chunks)), - ) - - -async def _handle_floating_request( - websocket: WebSocket, - user_id: str, - frame: dict, -) -> None: - """Handle a floating_request frame — streams FloatingFormatter output back on the socket.""" - request_id = frame.get("request_id") or str(uuid4()) - message: str = frame.get("message", "") - session_id: str = frame.get("session_id") or str(uuid4()) - scope: dict = frame.get("scope", {}) - logger.info( - "device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s", - user_id, - request_id, - session_id, - json.dumps(scope, ensure_ascii=True)[:200], - message[:200], - ) - - # ── Memory: enrich context before LLM call ──────────────────────── - async with async_session() as db: - memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context( - user_id, - message, - trace_id=request_id, - session_id=session_id, - ) - - context: dict = { - "scope": scope, - "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, - **memory_context, - } - - executor = await _make_ws_executor(websocket, user_id) - set_client_executor(executor) - response_chunks: list[str] = [] - try: - event_stream = run_floating_stream(user_id, message, context) - formatter = StreamFormatter(request_id=request_id) - async for ws_frame in formatter.format(event_stream): - await websocket.send_text(ws_frame.model_dump_json()) - if ws_frame.type == "stream_text": # type: ignore[union-attr] - response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] - except Exception as exc: - logger.error( - "device_ws: floating_request failed user=%s req=%s: %s", - user_id, request_id, exc, - ) - finally: - clear_client_executor() - - # ── Memory: store episode after response ────────────────────────── - async with async_session() as db: - memory = MemoryMiddleware(db) - await memory.store_episode( - user_id, session_id, message, "".join(response_chunks), trace_id=request_id - ) - logger.info( - "device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d", - user_id, - request_id, - session_id, - len("".join(response_chunks)), - ) - - -# ── v4 Journey Handlers ───────────────────────────────────────────── - - -async def _handle_journey_start( - websocket: WebSocket, - user_id: str, - frame: dict, -) -> None: - """Handle a journey_start frame — explores directory and sends first question.""" - executor = await _make_ws_executor(websocket, user_id) - set_client_executor(executor) - try: - reply = await handle_journey_start(user_id, frame) - await websocket.send_text(json.dumps(reply)) - except Exception as exc: - logger.error( - "device_ws: journey_start failed user=%s: %s", user_id, exc - ) - await websocket.send_text(json.dumps({ - "type": "journey_reply", - "session_id": frame.get("session_id", ""), - "message": f"Failed to start journey: {exc}", - "done": True, - "prompt_template": None, - })) - finally: - clear_client_executor() - - -async def _handle_journey_message( - websocket: WebSocket, - user_id: str, - frame: dict, -) -> None: - """Handle a journey_message frame — continues the journey conversation.""" - executor = await _make_ws_executor(websocket, user_id) - set_client_executor(executor) - try: - reply = await handle_journey_message(user_id, frame) - await websocket.send_text(json.dumps(reply)) - except Exception as exc: - session_id = frame.get("session_id", "") - logger.error( - "device_ws: journey_message failed user=%s session=%s: %s", - user_id, session_id, exc, - ) - await websocket.send_text(json.dumps({ - "type": "journey_reply", - "session_id": session_id, - "message": f"Journey error: {exc}", - "done": True, - "prompt_template": None, - })) - finally: - clear_client_executor() - - -# ── Heartbeat ───────────────────────────────────────────────────────── - -async def _heartbeat_loop(websocket: WebSocket) -> None: - """Send a ping frame every 30 s to keep the connection alive.""" - while True: - await asyncio.sleep(_HEARTBEAT_INTERVAL) - await websocket.send_text(json.dumps({"type": "ping"})) - - -# ── Disconnect cleanup ──────────────────────────────────────────────── - -async def _mark_runs_disconnected(user_id: str) -> None: - """Mark all in-progress AgentRunLog rows as 'error' for this user.""" - try: - async with async_session() as db: - await db.execute( - update(AgentRunLog) - .where( - AgentRunLog.user_id == user_id, - AgentRunLog.status == "running", - ) - .values( - status="error", - errors=["device disconnected"], - ) - ) - await db.commit() - except Exception as exc: - logger.error( - "device_ws: failed to mark runs as disconnected for user=%s: %s", - user_id, - exc, - ) diff --git a/app/api/routes/plugins.py b/app/api/routes/plugins.py deleted file mode 100644 index f3a2e6e..0000000 --- a/app/api/routes/plugins.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Plugins routes: browse and install plugins from the marketplace. - -Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that -persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables. -""" - -from __future__ import annotations - -from typing import Any, Literal - -from fastapi import APIRouter, Depends, HTTPException, Query, status -from pydantic import BaseModel -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.api.deps import get_current_user -from app.db import get_session -from app.marketplace.plugin_registry import registry -from app.marketplace.revenue_share import revenue_share -from app.models import PluginInstallation, PluginReview as PluginReviewModel -from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile - -router = APIRouter(prefix="/plugins", tags=["plugins"]) - - -# ── Tier gate ───────────────────────────────────────────────────────── - -def _require_plugin_tier(user: UserProfile) -> None: - """Raise HTTP 403 for users below Power tier.""" - if user.tier not in ("power", "team"): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Plugin marketplace requires Power tier or above", - ) - - -# ── Local detail schema ──────────────────────────────────────────────── - -class _PluginDetail(BaseModel): - plugin: PluginManifest - install_count: int - ratings: list[Any] - - -# ── Routes ──────────────────────────────────────────────────────────── - -@router.get("", response_model=PluginListResponse) -async def list_plugins( - category: str | None = Query(default=None), - q: str | None = Query(default=None), - page: int = Query(default=1, ge=1), - sort: Literal["rating", "installs", "newest"] = Query(default="newest"), - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> PluginListResponse: - """Browse the plugin marketplace. Requires Power tier or above.""" - _require_plugin_tier(current_user) - return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort) - - -@router.get("/{plugin_id}", response_model=_PluginDetail) -async def get_plugin( - plugin_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> _PluginDetail: - """Get full plugin details including install count. Requires Power tier or above.""" - _require_plugin_tier(current_user) - entry = await registry.get_plugin(db, plugin_id) - if entry is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found") - - # Fetch review ratings for this plugin - review_result = await db.execute( - select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id) - ) - reviews = review_result.scalars().all() - ratings = [ - { - "reviewer_id": r.reviewer_id, - "decision": r.decision, - "notes": r.notes, - "reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None, - } - for r in reviews - ] - - return _PluginDetail( - plugin=entry["manifest"], - install_count=entry["install_count"], - ratings=ratings, - ) - - -@router.post("/{plugin_id}/install", response_model=dict) -async def install_plugin( - plugin_id: str, - body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, Any]: - """Install a plugin. Triggers Stripe Connect revenue split for paid plugins. - - Requires Power tier or above. - """ - _require_plugin_tier(current_user) - entry = await registry.get_plugin(db, plugin_id) - if entry is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found") - - # Record the installation in plugin_installations - installation = PluginInstallation( - plugin_id=plugin_id, - user_id=current_user.id, - ) - db.add(installation) - await db.flush() - - await revenue_share.record_install( - db, - plugin_id=plugin_id, - user_id=current_user.id, - amount_cents=entry["manifest"].price_cents, - ) - - download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip" - return {"ok": True, "download_url": download_url} - - -@router.delete("/{plugin_id}/install", response_model=dict) -async def uninstall_plugin( - plugin_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Unregister a plugin installation.""" - result = await db.execute( - select(PluginInstallation).where( - PluginInstallation.plugin_id == plugin_id, - PluginInstallation.user_id == current_user.id, - ) - ) - installation = result.scalar_one_or_none() - if installation is not None: - await db.delete(installation) - await db.commit() - await registry.record_uninstall(db, plugin_id) - return {"ok": True} diff --git a/app/api/routes/storage.py b/app/api/routes/storage.py deleted file mode 100644 index ae71abd..0000000 --- a/app/api/routes/storage.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Storage routes: CRUD for E2E-encrypted cloud records. - -Blobs are stored in S3 via BlobStore. Record metadata is persisted in the -PostgreSQL ``storage_records`` table. -""" - -from __future__ import annotations - -import uuid - -from fastapi import APIRouter, Depends, HTTPException, Query, Response, status -from pydantic import BaseModel -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.api.deps import get_current_user -from app.billing.tier_manager import tier_manager -from app.db import get_session -from app.models import StorageRecord -from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile -from app.storage.blob_store import BlobStore -from app.storage.encryption import reject_if_tampered - -router = APIRouter(prefix="/storage", tags=["storage"]) - -_blob_store = BlobStore() - - -# ── Local response schemas ───────────────────────────────────────────── - -class _CreateResponse(BaseModel): - id: str - created_at: int - - -class _RecordMeta(BaseModel): - id: str - table: str - checksum: str - created_at: int - updated_at: int - - -# ── Helpers ──────────────────────────────────────────────────────────── - -async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int: - """Return total bytes stored by *user_id*.""" - result = await db.execute( - select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where( - StorageRecord.user_id == user_id - ) - ) - return int(result.scalar_one()) - - -async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None: - """Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit.""" - current = await _current_usage_bytes(user.id, db) - tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes) - - -async def _get_record_for_user( - record_id: str, user_id: str, db: AsyncSession -) -> StorageRecord: - """Look up a record and verify ownership. Returns 404 on mismatch - to prevent user enumeration attacks.""" - result = await db.execute( - select(StorageRecord).where( - StorageRecord.id == record_id, StorageRecord.user_id == user_id - ) - ) - record = result.scalar_one_or_none() - if record is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found") - return record - - -# ── Routes ───────────────────────────────────────────────────────────── - -@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED) -async def create_record( - body: StorageRecordCreate, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> _CreateResponse: - """Upload a new E2E-encrypted blob. Verifies checksum before storing.""" - reject_if_tampered(body.blob, body.checksum) - await _check_quota(current_user, len(body.blob), db) - - record_id = str(uuid.uuid4()) - - s3_key = await _blob_store.upload( - current_user.id, body.table, record_id, body.blob, body.checksum - ) - - record = StorageRecord( - id=record_id, - user_id=current_user.id, - table_name=body.table, - s3_key=s3_key, - checksum=body.checksum, - size_bytes=len(body.blob), - ) - db.add(record) - await db.commit() - await db.refresh(record) - - created_at_ms = int(record.created_at.timestamp() * 1000) - return _CreateResponse(id=record_id, created_at=created_at_ms) - - -@router.get("/records", response_model=list[_RecordMeta]) -async def list_records( - table: str | None = Query(default=None), - page: int = Query(default=1, ge=1), - limit: int = Query(default=50, ge=1, le=200), - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> list[_RecordMeta]: - """List record metadata for the authenticated user. Blob bytes are never returned.""" - query = select(StorageRecord).where(StorageRecord.user_id == current_user.id) - if table is not None: - query = query.where(StorageRecord.table_name == table) - query = query.offset((page - 1) * limit).limit(limit) - - result = await db.execute(query) - rows = result.scalars().all() - - return [ - _RecordMeta( - id=r.id, - table=r.table_name, - checksum=r.checksum, - created_at=int(r.created_at.timestamp() * 1000), - updated_at=int(r.updated_at.timestamp() * 1000), - ) - for r in rows - ] - - -@router.get("/records/{record_id}") -async def download_record( - record_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> Response: - """Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header.""" - record = await _get_record_for_user(record_id, current_user.id, db) - blob = await _blob_store.download(current_user.id, record.s3_key) - return Response( - content=blob, - media_type="application/octet-stream", - headers={"X-Checksum": record.checksum}, - ) - - -@router.put("/records/{record_id}", response_model=dict) -async def update_record( - record_id: str, - body: StorageRecordUpdate, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Replace the blob for an existing record. Verifies checksum before storing.""" - record = await _get_record_for_user(record_id, current_user.id, db) - reject_if_tampered(body.blob, body.checksum) - - delta = len(body.blob) - record.size_bytes - if delta > 0: - await _check_quota(current_user, delta, db) - - s3_key = await _blob_store.upload( - current_user.id, record.table_name, record_id, body.blob, body.checksum - ) - - record.s3_key = s3_key - record.checksum = body.checksum - record.size_bytes = len(body.blob) - await db.commit() - - return {"ok": True} - - -@router.delete("/records/{record_id}", response_model=dict) -async def delete_record( - record_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Delete a record and its S3 blob.""" - record = await _get_record_for_user(record_id, current_user.id, db) - await _blob_store.delete(current_user.id, record.s3_key) - await db.delete(record) - await db.commit() - return {"ok": True} diff --git a/app/api/routes/vectors.py b/app/api/routes/vectors.py deleted file mode 100644 index a03e602..0000000 --- a/app/api/routes/vectors.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text.""" - -from __future__ import annotations - -from fastapi import APIRouter, Depends -from pydantic import BaseModel - -from app.api.deps import get_current_user -from app.core.llm import embed -from app.schemas import ( - UserProfile, - VectorSearchRequest, - VectorSearchResponse, - VectorUpsertRequest, -) -from app.storage.encryption import reject_if_tampered -from app.storage.vector_store import VectorStore - -router = APIRouter(prefix="/storage", tags=["vectors"]) - -_vector_store = VectorStore() - - -class _VectorDeleteRequest(BaseModel): - ids: list[str] - - -class _EmbedRequest(BaseModel): - text: str - - -class _EmbedResponse(BaseModel): - vector: list[float] - - -@router.post("/vectors/upsert", response_model=dict) -async def upsert_vectors( - body: VectorUpsertRequest, - current_user: UserProfile = Depends(get_current_user), -) -> dict[str, int]: - """Verify checksums and store encrypted vectors in the user-scoped namespace.""" - for item in body.vectors: - reject_if_tampered(item.blob, item.checksum) - await _vector_store.upsert(current_user.id, body.vectors) - return {"upserted": len(body.vectors)} - - -@router.post("/vectors/search", response_model=VectorSearchResponse) -async def search_vectors( - body: VectorSearchRequest, - current_user: UserProfile = Depends(get_current_user), -) -> VectorSearchResponse: - """Search the user-scoped vector namespace with an encrypted query blob.""" - results = await _vector_store.search(current_user.id, body.query_blob, body.top_k) - return VectorSearchResponse(results=results) - - -@router.delete("/vectors", response_model=dict) -async def delete_vectors( - body: _VectorDeleteRequest, - current_user: UserProfile = Depends(get_current_user), -) -> dict[str, bool]: - """Delete vectors by ID, scoped to the authenticated user.""" - await _vector_store.delete(current_user.id, body.ids) - return {"ok": True} - - -@router.post("/vectors/embed", response_model=_EmbedResponse) -async def embed_text( - body: _EmbedRequest, - current_user: UserProfile = Depends(get_current_user), -) -> _EmbedResponse: - """Generate a 1536-dim embedding vector for the given text. - - Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT). - Used by backend tools (note_agent) and Electron (vectordb.ts) alike. - """ - vector = await embed(body.text) - return _EmbedResponse(vector=vector) diff --git a/app/billing/__init__.py b/app/billing/__init__.py deleted file mode 100644 index ef83f83..0000000 --- a/app/billing/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from app.billing.stripe_service import stripe_service -from app.billing.tier_manager import tier_manager - -__all__ = ["stripe_service", "tier_manager"] diff --git a/app/billing/stripe_service.py b/app/billing/stripe_service.py deleted file mode 100644 index 3bd9038..0000000 --- a/app/billing/stripe_service.py +++ /dev/null @@ -1,256 +0,0 @@ -"""Stripe service: checkout sessions, webhook handling, subscription management. - -Subscription records are persisted in the PostgreSQL ``subscriptions`` table. -All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not -configured, enabling local development without live credentials. -""" - -from __future__ import annotations - -from datetime import datetime, timezone -from typing import Any - -import stripe as stripe_lib -from fastapi import HTTPException, status -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.config.settings import settings - -# Stripe price IDs per tier — replace with real IDs in production .env -TIER_PRICE_IDS: dict[str, str] = { - "pro": "price_pro_monthly", - "power": "price_power_monthly", - "team": "price_team_monthly", -} - - -class StripeService: - """Wraps all Stripe interactions and owns subscription persistence.""" - - # ── Internal helpers ──────────────────────────────────────────────── - - def _configured(self) -> bool: - return bool(settings.STRIPE_SECRET_KEY) - - def _client(self) -> Any: - stripe_lib.api_key = settings.STRIPE_SECRET_KEY - return stripe_lib - - # ── Public API ────────────────────────────────────────────────────── - - def create_checkout_session( - self, - user_id: str, - tier: str, - success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}", - cancel_url: str = "https://app.adiuva.app/billing/cancel", - ) -> str: - """Create a Stripe checkout session and return the URL. - - Returns a stub URL when Stripe is not configured. - Raises ``HTTP 400`` for the free tier or an unknown tier. - """ - if tier == "free": - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot create a checkout session for the free tier", - ) - - price_id = TIER_PRICE_IDS.get(tier) - if not price_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unknown tier: {tier}", - ) - - if not self._configured(): - return "https://stripe.com/stub-checkout" - - s = self._client() - session = s.checkout.Session.create( - payment_method_types=["card"], - mode="subscription", - line_items=[{"price": price_id, "quantity": 1}], - success_url=success_url, - cancel_url=cancel_url, - metadata={"user_id": user_id, "tier": tier}, - ) - return session.url - - async def handle_webhook( - self, - payload: bytes, - sig_header: str, - db: AsyncSession, - ) -> None: - """Process a Stripe webhook event. - - Verifies the signature, then dispatches on event type. - Raises ``HTTP 400`` on signature mismatch. - No-ops when Stripe is not configured. - """ - if not self._configured(): - return - - try: - s = self._client() - event = s.Webhook.construct_event( - payload, sig_header, settings.STRIPE_WEBHOOK_SECRET - ) - except stripe_lib.error.SignatureVerificationError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid Stripe signature", - ) - - event_type: str = event["type"] - data: dict[str, Any] = event["data"]["object"] - - if event_type == "checkout.session.completed": - user_id = data.get("metadata", {}).get("user_id") - tier = data.get("metadata", {}).get("tier", "free") - sub_id = data.get("subscription") - period_end_ts = data.get("current_period_end") - period_end = ( - datetime.fromtimestamp(period_end_ts, tz=timezone.utc) - if period_end_ts - else None - ) - if user_id: - await self._upsert_subscription( - db, user_id, sub_id, tier, "active", period_end - ) - - elif event_type == "customer.subscription.updated": - sub_id = data.get("id") - new_status = data.get("status", "active") - period_end_ts = data.get("current_period_end") - period_end = ( - datetime.fromtimestamp(period_end_ts, tz=timezone.utc) - if period_end_ts - else None - ) - if sub_id: - await self._update_subscription_by_stripe_id( - db, sub_id, status=new_status, current_period_end=period_end - ) - - elif event_type == "customer.subscription.deleted": - sub_id = data.get("id") - if sub_id: - await self._update_subscription_by_stripe_id( - db, sub_id, tier="free", status="canceled" - ) - - elif event_type == "invoice.payment_failed": - sub_id = data.get("subscription") - if sub_id: - await self._update_subscription_by_stripe_id( - db, sub_id, status="past_due" - ) - - await db.commit() - - async def get_subscription( - self, user_id: str, db: AsyncSession - ) -> dict[str, Any] | None: - """Return the subscription record for ``user_id``, or ``None`` if absent.""" - from app.models import Subscription # noqa: PLC0415 - - result = await db.execute( - select(Subscription).where(Subscription.user_id == user_id) - ) - sub = result.scalar_one_or_none() - if sub is None: - return None - return { - "tier": sub.tier, - "stripe_subscription_id": sub.stripe_subscription_id, - "status": sub.status, - "current_period_end": ( - int(sub.current_period_end.timestamp() * 1000) - if sub.current_period_end - else None - ), - } - - async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None: - """Cancel the user's Stripe subscription and downgrade them to free. - - Raises ``HTTP 404`` when no active subscription exists. - """ - from app.models import Subscription # noqa: PLC0415 - - result = await db.execute( - select(Subscription).where(Subscription.user_id == user_id) - ) - sub = result.scalar_one_or_none() - if sub is None or not sub.stripe_subscription_id: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No active subscription found", - ) - - if self._configured(): - s = self._client() - s.Subscription.cancel(sub.stripe_subscription_id) - - sub.tier = "free" - sub.status = "canceled" - await db.commit() - - # ── Private DB helpers ─────────────────────────────────────────────── - - async def _upsert_subscription( - self, - db: AsyncSession, - user_id: str, - stripe_subscription_id: str | None, - tier: str, - sub_status: str, - current_period_end: datetime | None, - ) -> None: - from app.models import Subscription # noqa: PLC0415 - - result = await db.execute( - select(Subscription).where(Subscription.user_id == user_id) - ) - sub = result.scalar_one_or_none() - if sub is None: - sub = Subscription(user_id=user_id) - db.add(sub) - sub.stripe_subscription_id = stripe_subscription_id - sub.tier = tier - sub.status = sub_status - sub.current_period_end = current_period_end - - async def _update_subscription_by_stripe_id( - self, - db: AsyncSession, - stripe_subscription_id: str, - *, - tier: str | None = None, - status: str | None = None, - current_period_end: datetime | None = None, - ) -> None: - from app.models import Subscription # noqa: PLC0415 - - result = await db.execute( - select(Subscription).where( - Subscription.stripe_subscription_id == stripe_subscription_id - ) - ) - sub = result.scalar_one_or_none() - if sub is None: - return - if tier is not None: - sub.tier = tier - if status is not None: - sub.status = status - if current_period_end is not None: - sub.current_period_end = current_period_end - - -# Module-level singleton shared across the app. -stripe_service = StripeService() diff --git a/app/billing/tier_manager.py b/app/billing/tier_manager.py deleted file mode 100644 index ed5f3de..0000000 --- a/app/billing/tier_manager.py +++ /dev/null @@ -1,195 +0,0 @@ -"""Tier manager: feature matrix and quota enforcement. - -``TierManager`` is the single source of truth for what each billing tier -allows. ``get_tier`` queries the ``subscriptions`` table for the live tier. -Quota-enforcement helpers take ``tier`` directly — the caller already has it -from ``current_user.tier`` (provided by ``get_current_user``). -""" - -from __future__ import annotations - -from typing import Any - -from fastapi import HTTPException, status -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from app.schemas import BillingTier - -# Feature matrix per tier. -1 means unlimited; 0 means disabled. -FEATURES: dict[str, dict[str, Any]] = { - "free": { - "agents": 3, - "batch_active": 2, - "batch_runs_per_day": 5, - "cloud_storage_gb": 0, - "backup_gb": 0, - "providers": 1, - "batch_builder": False, - "plugin_marketplace": False, - "sso": False, - }, - "pro": { - "agents": -1, # unlimited - "batch_active": 10, - "batch_runs_per_day": 50, - "cloud_storage_gb": 5, - "backup_gb": 5, - "providers": -1, - "batch_builder": False, - "plugin_marketplace": False, - "sso": False, - }, - "power": { - "agents": -1, - "batch_active": -1, # unlimited - "batch_runs_per_day": -1, # unlimited - "cloud_storage_gb": 25, - "backup_gb": 25, - "providers": -1, - "batch_builder": True, - "plugin_marketplace": True, - "sso": False, - }, - "team": { - "agents": -1, - "batch_active": -1, - "batch_runs_per_day": -1, # unlimited - "cloud_storage_gb": -1, # unlimited - "backup_gb": -1, # unlimited - "providers": -1, - "batch_builder": True, - "plugin_marketplace": True, - "sso": True, - }, -} - -# Requests-per-minute limit per tier. -RATE_LIMITS: dict[str, int] = { - "free": 20, - "pro": 60, - "power": 120, - "team": 200, -} - - -class TierManager: - """Centralises tier feature-gating, rate-limit lookups, and quota checks.""" - - # ── Tier lookup ───────────────────────────────────────────────────── - - async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: - """Return the current billing tier for ``user_id`` from the DB. - - Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod - when no subscription row exists. - """ - from app.models import Subscription # noqa: PLC0415 - from app.config.settings import settings # noqa: PLC0415 - - result = await db.execute( - select(Subscription.tier).where(Subscription.user_id == user_id) - ) - tier: str | None = result.scalar_one_or_none() - if tier is None or tier not in FEATURES: - return "power" if settings.ENV == "dev" else "free" - return tier # type: ignore[return-value] - - # ── Feature access ─────────────────────────────────────────────────── - - def check_feature(self, tier: BillingTier, feature: str) -> bool: - """Return ``True`` if ``tier`` has ``feature`` enabled. - - For numeric features, any value > 0 or -1 (unlimited) counts as enabled. - """ - value = FEATURES.get(tier, FEATURES["free"]).get(feature) - if value is None: - return False - if isinstance(value, bool): - return value - return value != 0 - - def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None: - """Raise ``HTTP 403`` if ``tier`` does not have ``feature``.""" - if not self.check_feature(tier, feature): - detail = ( - f"Feature '{feature}' requires {tier_name} tier or above." - if tier_name - else f"Feature '{feature}' is not available on your current tier." - ) - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) - - # ── Rate limiting ──────────────────────────────────────────────────── - - def get_rate_limit(self, tier: BillingTier) -> int: - """Return the requests-per-minute limit for ``tier``.""" - return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) - - # ── Storage quota ──────────────────────────────────────────────────── - - def enforce_quota( - self, - tier: BillingTier, - current_bytes: int = 0, - additional_bytes: int = 0, - ) -> None: - """Raise ``HTTP 402`` if the user would exceed their cloud storage quota. - - ``tier`` is the caller's current tier (from ``current_user.tier``). - ``current_bytes`` is the total bytes already stored (queried by caller). - """ - limit_gb: int = FEATURES[tier]["cloud_storage_gb"] - if limit_gb == 0: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Cloud storage is not available on the '{tier}' tier", - ) - if limit_gb == -1: - return # unlimited - limit_bytes = limit_gb * 1024 ** 3 - if current_bytes + additional_bytes > limit_bytes: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Storage quota exceeded for tier '{tier}'", - ) - - def enforce_backup_quota( - self, - tier: BillingTier, - current_bytes: int = 0, - additional_bytes: int = 0, - ) -> None: - """Raise ``HTTP 402`` if the user would exceed their backup quota.""" - limit_gb: int = FEATURES[tier]["backup_gb"] - if limit_gb == 0: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Backup is not available on the '{tier}' tier", - ) - if limit_gb == -1: - return # unlimited - limit_bytes = limit_gb * 1024 ** 3 - if current_bytes + additional_bytes > limit_bytes: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Backup quota exceeded for tier '{tier}'", - ) - - def check_quota( - self, - tier: BillingTier, - current_bytes: int = 0, - additional_bytes: int = 0, - ) -> bool: - """Return ``True`` if the user can store ``additional_bytes`` more data.""" - limit_gb: int = FEATURES[tier]["cloud_storage_gb"] - if limit_gb == 0: - return False - if limit_gb == -1: - return True - limit_bytes = limit_gb * 1024 ** 3 - return current_bytes + additional_bytes <= limit_bytes - - -# Module-level singleton shared across the app. -tier_manager = TierManager() diff --git a/app/config/__init__.py b/app/config/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/config/settings.py b/app/config/settings.py deleted file mode 100644 index e566969..0000000 --- a/app/config/settings.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import Literal -from pydantic_settings import BaseSettings, SettingsConfigDict - - -class Settings(BaseSettings): - DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva" - JWT_SECRET: str = "change-me-in-production" - JWT_ALGORITHM: str = "HS256" - JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 - JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30 - - STRIPE_SECRET_KEY: str = "" - STRIPE_WEBHOOK_SECRET: str = "" - - S3_BUCKET: str = "" - S3_REGION: str = "us-east-1" - S3_ENDPOINT_URL: str = "" - AWS_ACCESS_KEY_ID: str = "" - AWS_SECRET_ACCESS_KEY: str = "" - - PINECONE_API_KEY: str = "" - PINECONE_INDEX: str = "adiuva" - QDRANT_URL: str = "" - QDRANT_API_KEY: str = "" - - OPENAI_API_KEY: str = "" - ANTHROPIC_API_KEY: str = "" - GOOGLE_API_KEY: str = "" - CEREBRAS_API_KEY: str = "" - GITHUB_TOKEN: str = "" - - LLM_MODEL: str = "gpt-4o" - LLM_EMBED_MODEL: str = "text-embedding-3-small" - - # GitHub Copilot OAuth token storage directory. - # Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot). - # In Docker, set this to a path backed by a named volume so tokens survive restarts. - GITHUB_COPILOT_TOKEN_DIR: str = "" - - # OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows. - GMAIL_CLIENT_ID: str = "" - GMAIL_CLIENT_SECRET: str = "" - MS_CLIENT_ID: str = "" - MS_CLIENT_SECRET: str = "" - # MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts). - MS_TENANT_ID: str = "common" - - # Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth - # tokens stored in cloud_agent_configs.oauth_token_encrypted. - # Generate with: from cryptography.fernet import Fernet; Fernet.generate_key() - OAUTH_ENCRYPTION_KEY: str = "" - - CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] - - ENV: Literal["dev", "prod"] = "dev" - - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore" - ) - - -settings = Settings() diff --git a/app/core/__init__.py b/app/core/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/app/core/agent_registry.py b/app/core/agent_registry.py deleted file mode 100644 index 95c2033..0000000 --- a/app/core/agent_registry.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Minimal agent base types retained for compatibility with batch runners.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from typing import Any - - -class BaseAgent(ABC): - """Common base for non-chat agents still using the old base contract.""" - - def __init__( - self, - user_id: str = "", - shared_memory: dict[str, Any] | None = None, - vector_store_context: list[str] | None = None, - ) -> None: - self.user_id = user_id - self.shared_memory: dict[str, Any] = shared_memory or {} - self.vector_store_context: list[str] = vector_store_context or [] - - @abstractmethod - def get_name(self) -> str: ... - - @abstractmethod - def get_description(self) -> str: ... - - @property - def skills(self) -> list[str]: - return [] diff --git a/app/core/agent_runner.py b/app/core/agent_runner.py deleted file mode 100644 index c11324e..0000000 --- a/app/core/agent_runner.py +++ /dev/null @@ -1,1064 +0,0 @@ -"""Agent run orchestrator. - -Drives two agent types: - -* **Local directory agent** — two-step execution per file: - Step 1 (Classification) uses code to fetch all projects and asks the LLM - to identify which project the file belongs to and which domains are relevant. - Step 2 (Processing) fetches existing entities for that project/domains via - code and runs an LLM with tools — existing data in context enforces - update-first naturally. - -* **Cloud connector agent** — fetches data from third-party APIs (Gmail, - Teams, Outlook) and pushes extracted items to Electron. - -Usage ------ -Background tasks are spawned with ``asyncio.create_task()``:: - - asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager)) - asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) - -The ``trigger_pending_runs`` function is called by the device WS endpoint -when Electron sends ``device_hello``, so any overdue runs fire immediately -when the device reconnects. -""" - -from __future__ import annotations - -import asyncio -import json -import logging -import uuid -from datetime import datetime, timedelta, timezone -from typing import Any - -from croniter import croniter -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage -from sqlalchemy import select - -from app.agents.filesystem_agent import FILESYSTEM_TOOLS -from app.agents.note_agent import NOTE_TOOLS -from app.agents.project_agent import PROJECT_TOOLS -from app.agents.task_agent import TASK_TOOLS -from app.agents.timeline_agent import TIMELINE_TOOLS -from app.core.device_manager import DeviceConnectionManager -from app.core.llm import get_llm -from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor -from app.db import async_session -from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig - -logger = logging.getLogger(__name__) - -# ── Concurrency guard ───────────────────────────────────────────────────── -# Tracks agent IDs that currently have a run in progress. -# Prevents multiple simultaneous runs of the same agent within a single process. -_running_agents: set[str] = set() - - -def is_agent_running(agent_id: str) -> bool: - """Return ``True`` if *agent_id* already has a run in progress.""" - return agent_id in _running_agents - -# ── Timeouts ─────────────────────────────────────────────────────────────── - -# Max seconds to wait for a single tool-call round-trip (FE → BE). -_TOOL_CALL_TIMEOUT: int = 30 -# Max LLM reasoning steps for Step 2 processing. -_MAX_PROCESSING_STEPS: int = 12 -# Max directory recursion depth during scan. -_MAX_SCAN_DEPTH: int = 5 - -# ── Data-type to tool mapping ───────────────────────────────────────────── -# NOTE: "projects" is intentionally excluded — project creation/assignment is -# handled in code by the runner, never delegated to the Step 2 LLM. - -_DATA_TYPE_TOOLS: dict[str, list[Any]] = { - "tasks": TASK_TOOLS, - "notes": NOTE_TOOLS, - "timelines": TIMELINE_TOOLS, -} - -# ── Step 1: Classification prompt ───────────────────────────────────────── - -_DOMAIN_DESCRIPTIONS: dict[str, str] = { - "tasks": ( - "Action items, to-dos, deliverables — anything that describes work to be done, " - "assigned to someone, or tracked with a due date or status." - ), - "notes": ( - "Documentation, meeting notes, summaries, reference material — " - "written content meant to be read and referenced rather than acted on." - ), - "timelines": ( - "Project milestones, deadlines, scheduled events — " - "specific dates that mark a point in the progress of a project." - ), - "projects": ( - "High-level project entities — only relevant if the file clearly introduces " - "a new project or updates the scope of an existing one." - ), -} - -_STEP1_SYSTEM_PROMPT = """\ -You are a file classifier for a freelance project management tool. - -Your job is to match a file to an existing project and identify which data domains to extract. - -## Project matching rules (STRICT — follow in order) - -1. Search the file content for any mention of a project name, client name, acronym, or topic - that overlaps with the existing projects listed below. -2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough. -3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort - when the file has zero meaningful connection to any listed project. -4. When in doubt, pick the closest match from the list. - -## Response format - -Respond ONLY with a JSON object — no markdown, no explanation: - -{{"project_id": "", "new_project_name": "", "domains": ["tasks", "notes"]}} - -## Domain definitions (only consider domains in the allowed list) - -{domain_definitions} - -## Existing projects - -{projects_list} -""" - -# ── Step 2: Processing prompt ───────────────────────────────────────────── - -_PROCESSING_SYSTEM_PROMPT = """\ -You are a data extraction assistant for a freelance project management tool. - -Your task: extract structured data from the file content and persist it using the available tools. - -## Mandatory process — follow this order for EVERY item you extract - -1. READ the existing records listed below for the relevant domain. -2. SEARCH for a match by title, topic, or semantic similarity. -3. If a match exists → call the update_* tool with the existing record's id. -4. If no match exists → call the create_* tool and set isAiSuggested=1. - -NEVER call create_* without first checking the existing records. -NEVER duplicate a record that already exists under a different wording. - -## Existing records (source of truth) - -{existing_context} - -## Context - -Project: {project_context} -Domains to extract: {data_types} - -{custom_prompt_section} -""" - -# ── Cloud processing prompt (kept separate for cloud agent) ─────────────── - -_CLOUD_PROCESSING_PROMPT = """\ -You are a data extraction and management assistant for a freelance project -management tool. - -Available tools: - Filesystem : read_file_content, list_directory, get_file_metadata - Tasks : list_tasks, create_task, update_task, add_task_comment - Notes : list_notes, get_note, create_note, update_note - Timelines : list_timelines, create_timeline, update_timeline - Projects : list_all_projects, get_project, create_project, update_project - -Your task: -1. Read the full content of each file below using read_file_content. -2. For each piece of information found, ALWAYS try to match and update an - existing record before creating a new one. -3. ONLY act on these entity types: {data_types}. -4. Do NOT invent data. Only extract what is clearly present in the files. -5. If a file contains no relevant data for the target entity types, skip it. - -{project_context} - -Files to process: -{file_list} - -{custom_prompt_section} - -After processing all files, respond with a brief summary of what you updated -and what you created. -""" - - -# ── Cron helper ──────────────────────────────────────────────────────────── - - -def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool: - """Return ``True`` if the next scheduled run time has already passed. - - Always validates the cron expression first — an invalid expression returns - ``False`` (fail-safe: never trigger an unparseable schedule). - """ - try: - now = datetime.now(timezone.utc) - if last_run_at is None: - croniter(schedule_cron, now) - return True - ts = last_run_at - if ts.tzinfo is None: - ts = ts.replace(tzinfo=timezone.utc) - cron = croniter(schedule_cron, ts) - next_run: datetime = cron.get_next(datetime) - return now >= next_run - except Exception as exc: - logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc) - return False - - -# ── WS executor for agent context ───────────────────────────────────────── - - -def _make_agent_executor( - user_id: str, - device_mgr: DeviceConnectionManager, - run_context: dict | None = None, -) -> Any: - """Create a WS callback for ``set_client_executor()`` so that all tools - can use ``execute_on_client()`` during an agent run. - - If *run_context* is provided it is attached to every ``tool_call`` frame - so the Electron client can attribute actions to the correct agent run. - """ - async def _executor(payload: dict) -> dict: - payload["type"] = "tool_call" - if run_context: - payload["run_context"] = run_context - call_id = payload["id"] - fut = device_mgr.create_pending_call(user_id, call_id) - await device_mgr.send_frame(user_id, payload) - return await asyncio.wait_for(fut, timeout=_TOOL_CALL_TIMEOUT) - return _executor - - -# ── LLM tool-calling loop ───────────────────────────────────────────────── - - -def _as_text(content: Any) -> str: - if content is None: - return "" - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, str): - parts.append(item) - elif isinstance(item, dict): - text = item.get("text") - if isinstance(text, str): - parts.append(text) - return "".join(parts) - return str(content) - - -async def _run_agent_with_tools( - *, - system_prompt: str, - user_message: str, - tools: list[Any], - max_steps: int, -) -> str: - """Run an LLM agent with tool-calling, returning the final text response.""" - llm = get_llm() - llm_with_tools = llm.bind_tools(tools) - messages: list[Any] = [ - SystemMessage(content=system_prompt), - HumanMessage(content=user_message), - ] - - tool_map = {tool_def.name: tool_def for tool_def in tools} - - for _ in range(max_steps): - response: AIMessage = await llm_with_tools.ainvoke(messages) - messages.append(response) - - if not response.tool_calls: - return _as_text(response.content) - - for call in response.tool_calls: - call_id = str(call.get("id", "")) - call_name = str(call.get("name", "")) - call_args = call.get("args", {}) - logger.info( - "agent_runner: tool_call name=%s args=%s", - call_name, - json.dumps(call_args, ensure_ascii=True)[:800], - ) - - tool_fn = tool_map.get(call_name) - if tool_fn is None: - tool_output = f"Unknown tool: {call_name}" - else: - tool_output = await tool_fn.ainvoke(call_args) - - logger.info( - "agent_runner: tool_result name=%s output=%s", - call_name, - str(tool_output)[:200], - ) - messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) - - final = await llm.ainvoke(messages) - return _as_text(final.content) - - -# ── Tool list builder ───────────────────────────────────────────────────── - - -def _build_processing_tools(data_types: list[str]) -> list[Any]: - """Build the tool list for processing based on user's data_types selection.""" - tools: list[Any] = list(FILESYSTEM_TOOLS) - for dt in data_types: - dt_tools = _DATA_TYPE_TOOLS.get(dt) - if dt_tools: - tools.extend(dt_tools) - return tools - - -# ── Code-based directory scanner ───────────────────────────────────────── - - -async def _scan_directories( - paths: list[str], - extensions: list[str], - last_run_at: datetime | None, -) -> list[str]: - """Walk directories via WS tool calls and return filtered file paths. - - Recursion is capped at ``_MAX_SCAN_DEPTH``. Files are filtered by - extension (if configured) and by modification date (if ``last_run_at`` - is set). Fails open: if metadata cannot be read, the file is included. - """ - all_files: list[str] = [] - ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set() - - async def _walk(path: str, depth: int) -> None: - if depth > _MAX_SCAN_DEPTH: - return - try: - result = await execute_on_client(action="list_directory", data={"path": path}) - except Exception as exc: - logger.warning("agent_runner: list_directory failed %r: %s", path, exc) - return - for entry in result.get("entries", []): - entry_path = entry.get("path", "") - if not entry_path: - continue - if entry.get("type") == "directory": - await _walk(entry_path, depth + 1) - elif entry.get("type") == "file": - if ext_set: - dot_pos = entry_path.rfind(".") - file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else "" - if file_ext not in ext_set: - continue - all_files.append(entry_path) - - for root in paths: - await _walk(root, depth=0) - - if last_run_at is None: - return all_files - - # Filter by modification date. - last_run_ms = int(last_run_at.timestamp() * 1000) - filtered: list[str] = [] - for file_path in all_files: - try: - meta = await execute_on_client(action="get_file_metadata", data={"path": file_path}) - modified_at = meta.get("modifiedAt") - if modified_at is None: - filtered.append(file_path) - continue - if isinstance(modified_at, (int, float)): - mod_ms = int(modified_at) - else: - mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000) - if mod_ms > last_run_ms: - filtered.append(file_path) - except Exception: - filtered.append(file_path) # fail-open - - return filtered - - -# ── Code-based entity fetchers ──────────────────────────────────────────── - - -async def _fetch_projects() -> list[dict]: - """Fetch all projects from the Electron client via WS.""" - try: - result = await execute_on_client(action="select", table="projects") - return result.get("rows", []) - except Exception as exc: - logger.warning("agent_runner: failed to fetch projects: %s", exc) - return [] - - -_DOMAIN_TABLE: dict[str, str] = { - "tasks": "tasks", - "notes": "notes", - "timelines": "timelines", - "projects": "projects", -} - - -async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]: - """Fetch existing rows for a domain, scoped to a project where applicable.""" - table = _DOMAIN_TABLE.get(domain) - if not table: - return [] - filters: dict[str, Any] = {} - if project_id != "standalone" and domain != "projects": - filters["projectId"] = project_id - try: - result = await execute_on_client( - action="select", - table=table, - filters=filters if filters else None, - ) - return result.get("rows", []) - except Exception as exc: - logger.warning("agent_runner: failed to fetch %s: %s", domain, exc) - return [] - - -def _format_entities_for_context(domain: str, rows: list[dict]) -> str: - """Format existing entity rows as a readable context block for the LLM. - - Includes enough detail per record for the LLM to make a confident - update-vs-create decision without overwhelming the context. - Note content is truncated to 200 chars to stay within token budget. - """ - if not rows: - return f"No existing {domain}." - lines: list[str] = [] - for r in rows: - if domain == "tasks": - desc = r.get("description") or "" - desc_part = f" — {desc[:120]}" if desc else "" - assignee = r.get("assignee") or r.get("assignees") or "" - due = r.get("dueDate") or r.get("due_date") or "" - meta = ", ".join(filter(None, [ - f"priority: {r.get('priority', '')}" if r.get("priority") else "", - f"assignee: {assignee}" if assignee else "", - f"due: {due}" if due else "", - ])) - lines.append( - f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}" - f" ({meta}, id: {r['id']})" - ) - elif domain == "notes": - snippet = (r.get("content") or "")[:200].replace("\n", " ") - snippet_part = f"\n Preview: {snippet}" if snippet else "" - lines.append( - f" - {r.get('title', '')} (id: {r['id']}){snippet_part}" - ) - elif domain == "timelines": - lines.append( - f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})" - ) - elif domain == "projects": - summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120] - summary_part = f" — {summary}" if summary else "" - lines.append( - f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}" - f" (id: {r['id']})" - ) - return f"Existing {domain}:\n" + "\n".join(lines) - - -# ── Step 1: LLM file classifier ─────────────────────────────────────────── - - -async def _classify_file( - file_path: str, - file_content: str, - projects: list[dict], - config_data_types: list[str], -) -> tuple[str, list[str], str | None]: - """Call the LLM to classify a file by project and relevant domains. - - Returns ``(project_id_or_"new", domains, new_project_name_or_None)``. - - ``project_id`` is an existing project UUID, or ``"new"`` when no match found. - - ``new_project_name`` is only set when ``project_id == "new"``. - Falls back to ``("new", config_data_types, None)`` on any error. - """ - fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None) - - if not file_content.strip(): - return fallback - - valid_project_ids = {p["id"] for p in projects} - - def _fmt_project(p: dict) -> str: - summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip() - summary_part = f" — {summary[:100]}" if summary else "" - return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}" - - projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)" - - domain_definitions = "\n".join( - f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}" - for d in config_data_types - if d in _DOMAIN_DESCRIPTIONS - ) - - system = _STEP1_SYSTEM_PROMPT.format( - domain_definitions=domain_definitions, - projects_list=projects_list, - ) - - llm = get_llm() - try: - response = await llm.ainvoke([ - SystemMessage(content=system), - HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"), - ]) - raw = _as_text(response.content).strip() - # Strip markdown fences if the model wraps the JSON. - if raw.startswith("```"): - raw = raw.split("```")[1] - if raw.startswith("json"): - raw = raw[4:] - parsed = json.loads(raw.strip()) - raw_project_id: str = str(parsed.get("project_id") or "new") - # Reject hallucinated UUIDs — only accept ids that exist in the fetched list. - project_id = raw_project_id if raw_project_id in valid_project_ids else "new" - new_project_name: str | None = ( - str(parsed["new_project_name"]).strip() or None - if project_id == "new" and parsed.get("new_project_name") - else None - ) - domains: list[str] = [ - d for d in parsed.get("domains", []) - if d in config_data_types - ] - if not domains: - domains = list(config_data_types) - return project_id, domains, new_project_name - except Exception as exc: - logger.warning( - "agent_runner: step1 classification failed for %r: %s", file_path, exc - ) - return fallback - - -# ── Local agent runner (two-step per file) ──────────────────────────────── - - -async def run_local_agent( - user_id: str, - config: LocalAgentConfig, - run_log: AgentRunLog, - device_mgr: DeviceConnectionManager, - run_context: dict | None = None, -) -> None: - """Execute a local directory agent run using a two-step approach per file. - - Step 1 — Classification (code + 1 LLM call per file, no tools): - Code scans directories and fetches all projects via WS. - For each file, LLM identifies the project and relevant domains. - - Step 2 — Processing (code + 1 LLM call per file, with tools): - Code fetches existing entities for the identified project/domains. - LLM receives file content + existing entities in context and uses - tools to update existing records or create new ones. - """ - run_id = run_log.id - agent_id = (run_context or {}).get("agent_id") or config.id - _running_agents.add(agent_id) - - # ── Device online check ───────────────────────────────────────── - target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else "" - is_online = ( - device_mgr.is_online(user_id, target_device_id) - if target_device_id - else device_mgr.is_online(user_id) - ) - - if not is_online: - logger.info( - "agent_runner: skip run=%s — device %r offline for user=%s", - run_id, - target_device_id or "", - user_id, - ) - await _finalize_run( - run_log, - status="error", - errors=[f"Device {target_device_id or ''!r} is not connected"], - ) - return - - # ── Set up WS executor for tools ──────────────────────────────── - executor = _make_agent_executor(user_id, device_mgr, run_context) - set_client_executor(executor) - - errors: list[str] = [] - items_processed = 0 - items_created = 0 - - custom_section = ( - f"User instructions:\n{config.prompt_template}" - if config.prompt_template - else "" - ) - - try: - # ── Code: scan directories ─────────────────────────────────── - logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id) - file_paths = await _scan_directories( - paths=config.directory_paths, - extensions=config.file_extensions or [], - last_run_at=config.last_run_at, - ) - logger.info( - "agent_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths) - ) - - if not file_paths: - await _finalize_run(run_log, status="success", items_processed=0, items_created=0) - return - - # ── Code: fetch all projects once ──────────────────────────── - projects = await _fetch_projects() - - for file_path in file_paths: - try: - # Read file content via code. - file_result = await execute_on_client( - action="read_file_content", data={"path": file_path} - ) - file_content: str = file_result.get("content", "") - if not file_content: - logger.debug("agent_runner: run=%s skipping empty file %r", run_id, file_path) - continue - - items_processed += 1 - - # Step 1 — classify file. - project_id, domains, new_project_name = await _classify_file( - file_path=file_path, - file_content=file_content, - projects=projects, - config_data_types=config.data_types, - ) - logger.info( - "agent_runner: run=%s file=%r → project=%s new_name=%r domains=%s", - run_id, - file_path, - project_id, - new_project_name, - domains, - ) - - # Step 2 — resolve project_id via CODE, then fetch entities. - # Project creation is NEVER delegated to the Step 2 LLM. - if project_id == "new": - proj_name = new_project_name or "Untitled Project" - try: - proj_result = await execute_on_client( - action="insert", - table="projects", - data={"name": proj_name, "clientId": None}, - ) - created = proj_result.get("row", {}) - effective_project_id = created.get("id", "standalone") - # Add to local list so subsequent files can match it. - if "id" in created: - projects.append(created) - logger.info( - "agent_runner: run=%s created project %r id=%s", - run_id, proj_name, effective_project_id, - ) - except Exception as exc: - logger.warning( - "agent_runner: run=%s failed to create project %r: %s", - run_id, proj_name, exc, - ) - effective_project_id = "standalone" - proj_name = "unknown" - project_context = ( - f"Project: {proj_name} (id: {effective_project_id}). " - "Always set projectId to this id on every record you create." - ) - else: - effective_project_id = project_id - proj = next((p for p in projects if p["id"] == project_id), None) - proj_name = proj.get("name", project_id) if proj else project_id - project_context = ( - f"Project: {proj_name} (id: {project_id}). " - "Always set projectId to this id on every record you create." - ) - - # "projects" domain is never passed to Step 2 — handled above in code. - domains = [d for d in domains if d != "projects"] - - existing_blocks: list[str] = [] - for domain in domains: - rows = await _fetch_domain_entities(domain, effective_project_id) - existing_blocks.append(_format_entities_for_context(domain, rows)) - - existing_context = "\n\n".join(existing_blocks) - - system_prompt = _PROCESSING_SYSTEM_PROMPT.format( - existing_context=existing_context, - project_context=project_context, - data_types=", ".join(domains), - custom_prompt_section=custom_section, - ) - - processing_tools = _build_processing_tools(domains) - - result_text = await _run_agent_with_tools( - system_prompt=system_prompt, - user_message=( - f"Process this file and extract relevant information.\n\n" - f"File: {file_path}\n\nContent:\n{file_content}" - ), - tools=processing_tools, - max_steps=_MAX_PROCESSING_STEPS, - ) - logger.info( - "agent_runner: run=%s file=%r result=%s", - run_id, - file_path, - result_text[:200], - ) - - except Exception as exc: - errors.append(f"Error processing '{file_path}': {exc}") - logger.error( - "agent_runner: run=%s file=%r failed: %s", run_id, file_path, exc - ) - - except Exception as exc: - errors.append(f"Agent run failed: {exc}") - logger.error("agent_runner: run=%s failed: %s", run_id, exc) - finally: - _running_agents.discard(agent_id) - clear_client_executor() - - # ── Finalise ──────────────────────────────────────────────────── - if errors and items_processed == 0: - final_status = "error" - elif errors: - final_status = "partial" - else: - final_status = "success" - - await _finalize_run( - run_log, - status=final_status, - items_processed=items_processed, - items_created=items_created, - errors=errors, - ) - logger.info( - "agent_runner: run=%s done status=%s processed=%d errors=%d", - run_id, - final_status, - items_processed, - len(errors), - ) - - # Notify Electron that the run is complete. - if run_context and device_mgr.is_online(user_id): - try: - await device_mgr.send_frame(user_id, { - "type": "run_complete", - "run_context": run_context, - "status": final_status, - }) - except Exception as exc: - logger.warning( - "agent_runner: run=%s failed to send run_complete: %s", run_id, exc - ) - - -# ── Cloud agent runner ───────────────────────────────────────────────────── - -_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7 - - -async def run_cloud_agent( - user_id: str, - config: CloudAgentConfig, - run_log: AgentRunLog, - device_mgr: DeviceConnectionManager, -) -> None: - """Execute a cloud connector agent run end-to-end. - - Steps: - - 1. Verify the user's device is online. - 2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``. - 3. Instantiate the provider client (Gmail or MS Graph). - 4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for - the first run) applying ``config.filter_config`` filters. - 5. For each message/email call the LLM to extract structured items. - 6. Push each item to Electron as an ``insert`` tool-call. - 7. If the provider refreshed its access token, re-encrypt and write it - back to ``config.oauth_token_encrypted``. - 8. Persist the run outcome via ``_finalize_run``. - """ - run_id = run_log.id - - # ── 1. Device online check ───────────────────────────────────────── - if not device_mgr.is_online(user_id): - logger.info( - "agent_runner: skip cloud run=%s — no device online for user=%s", - run_id, - user_id, - ) - await _finalize_run( - run_log, - status="error", - errors=["No connected device — cloud agent results cannot be delivered"], - ) - return - - # ── 2. Decrypt OAuth token ───────────────────────────────────────── - from app.integrations import decrypt_token, encrypt_token, get_provider - - if not config.oauth_token_encrypted: - await _finalize_run( - run_log, - status="error", - errors=[f"No OAuth token stored for cloud agent '{config.name}'"], - ) - return - - try: - credentials_info = decrypt_token(config.oauth_token_encrypted) - except ValueError as exc: - logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc) - await _finalize_run( - run_log, - status="error", - errors=[f"Failed to decrypt OAuth token: {exc}"], - ) - return - - # ── 3. Instantiate provider client ──────────────────────────────── - try: - provider = get_provider(config.provider, credentials_info) - except ValueError as exc: - await _finalize_run(run_log, status="error", errors=[str(exc)]) - return - - # ── 4. Fetch messages ───────────────────────────────────────────── - since: datetime | None = config.last_run_at - if since is None: - since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS) - if since.tzinfo is None: - since = since.replace(tzinfo=timezone.utc) - - errors: list[str] = [] - items_processed = 0 - items_created = 0 - - try: - if config.provider == "gmail": - raw_messages = await provider.fetch_messages( # type: ignore[union-attr] - filter_config=config.filter_config, - since=since, - ) - elif config.provider == "outlook": - raw_messages = await provider.fetch_emails( # type: ignore[union-attr] - filter_config=config.filter_config, - since=since, - ) - elif config.provider == "teams": - raw_messages = await provider.fetch_messages( # type: ignore[union-attr] - filter_config=config.filter_config, - since=since, - ) - else: - raw_messages = [] - except RuntimeError as exc: - logger.error( - "agent_runner: provider fetch failed for cloud agent %s: %s", config.id, exc - ) - await _finalize_run( - run_log, - status="error", - errors=[f"Provider fetch failed: {exc}"], - update_config_last_run=True, - config_id=config.id, - config_type="cloud", - ) - return - - logger.info( - "agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s", - config.id, - len(raw_messages), - config.provider, - user_id, - ) - - # ── 5–6. Extract + insert via LLM with tools ───────────────────── - executor = _make_agent_executor(user_id, device_mgr) - set_client_executor(executor) - - try: - processing_tools = _build_processing_tools(config.data_types) - custom_section = ( - f"User instructions:\n{config.prompt_template}" - if config.prompt_template - else "" - ) - - for msg in raw_messages: - content_text = msg.as_text - if not content_text: - continue - items_processed += 1 - - processing_prompt = _CLOUD_PROCESSING_PROMPT.format( - data_types=", ".join(config.data_types), - project_context="Determine the appropriate project from the message context.", - file_list=f"Message from {config.provider} (id: {msg.id})", - custom_prompt_section=custom_section, - ) - - try: - await _run_agent_with_tools( - system_prompt=processing_prompt, - user_message=f"Process this message content:\n\n{content_text[:8000]}", - tools=processing_tools, - max_steps=_MAX_PROCESSING_STEPS, - ) - except Exception as exc: - errors.append(f"LLM processing error for message {msg.id!r}: {exc}") - finally: - clear_client_executor() - - # ── 7. Persist refreshed token (if any) ─────────────────────────── - refreshed = getattr(provider, "refreshed_credentials", None) - if refreshed: - try: - new_encrypted = encrypt_token(refreshed) - async with async_session() as db: - cfg_result = await db.execute( - select(CloudAgentConfig).where(CloudAgentConfig.id == config.id) - ) - cfg_row = cfg_result.scalar_one_or_none() - if cfg_row: - cfg_row.oauth_token_encrypted = new_encrypted - await db.commit() - logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id) - except Exception as exc: - logger.warning( - "agent_runner: failed to persist refreshed token for agent %s: %s", - config.id, - exc, - ) - - # ── 8. Finalise ──────────────────────────────────────────────────── - if errors and items_created == 0: - final_status = "error" - elif errors: - final_status = "partial" - else: - final_status = "success" - - await _finalize_run( - run_log, - status=final_status, - items_processed=items_processed, - items_created=items_created, - errors=errors, - update_config_last_run=True, - config_id=config.id, - config_type="cloud", - ) - logger.info( - "agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d", - run_id, - final_status, - items_processed, - items_created, - len(errors), - ) - - -# ── Pending-run trigger ───────────────────────────────────────────────────── - - -async def trigger_pending_runs( - user_id: str, - device_id: str, - device_mgr: DeviceConnectionManager, -) -> None: - """Dispatch any overdue agent runs after an Electron device connects. - - Called as a background task from the device WS endpoint on ``device_hello``. - """ - logger.info( - "agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)", - user_id, - device_id, - ) - return - - -# ── Internal helper ───────────────────────────────────────────────────────── - - -async def _finalize_run( - run_log: AgentRunLog, - *, - status: str, - items_processed: int = 0, - items_created: int = 0, - errors: list[str] | None = None, - update_config_last_run: bool = False, - config_id: str | None = None, - config_type: str | None = None, -) -> None: - """Persist the run outcome and optionally update ``last_run_at`` on the config.""" - now = datetime.now(timezone.utc) - try: - async with async_session() as db: - managed = await db.merge(run_log) - managed.status = status - managed.items_processed = items_processed - managed.items_created = items_created - managed.errors = errors or [] - managed.completed_at = now - - if update_config_last_run and config_id: - if config_type == "local": - cfg_result = await db.execute( - select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) - ) - cfg = cfg_result.scalar_one_or_none() - if cfg: - cfg.last_run_at = now - elif config_type == "cloud": - cfg_result = await db.execute( - select(CloudAgentConfig).where(CloudAgentConfig.id == config_id) - ) - cfg = cfg_result.scalar_one_or_none() - if cfg: - cfg.last_run_at = now - - await db.commit() - except Exception as exc: - logger.error( - "agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc - ) diff --git a/app/core/deep_agent.py b/app/core/deep_agent.py deleted file mode 100644 index 0e490a5..0000000 --- a/app/core/deep_agent.py +++ /dev/null @@ -1,846 +0,0 @@ -"""Single-agent runners for home and floating chat contexts.""" - -from __future__ import annotations - -import json -import logging -import re -from datetime import date -from collections.abc import AsyncGenerator -from typing import Any, Literal - -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage -from langchain_core.tools import tool - -from app.agents.note_agent import NOTE_TOOLS -from app.agents.project_agent import PROJECT_TOOLS -from app.agents.task_agent import TASK_TOOLS -from app.agents.timeline_agent import TIMELINE_TOOLS -from app.core.llm import get_llm -from app.core.memory_middleware import MemoryMiddleware -from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector -from app.db import async_session - -logger = logging.getLogger(__name__) - -FloatingDomainType = Literal["task", "timeline", "project", "node"] -FloatingDomainSection = Literal["task", "timeline", "note"] - -_HOME_SINGLE_AGENT_SYSTEM = ( - "You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " - "Always use tools for factual data retrieval before answering. " - "When the user asks to remember, forget, or update what you know about them, use memory tools. " - "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " - "Return markdown and use tags when relevant: [ids], [ids], " - "[ids], [ids], {json}. " - "When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. " - "Never put titles, priorities, or dates on the same line as or tags. " - "For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. " - "For upcoming tasks, after tag lines add a short recommendation based on due date and priority." -) - -_FLOATING_SINGLE_AGENT_SYSTEM = ( - "You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. " - "Stay focused on the floating scope in context.scope and answer concisely. " - "Return plain text only. Do not output XML/HTML-like tags such as , , , , or any bracketed id tag wrappers. " - "Always use tools for factual data retrieval before answering. " - "When the user asks to remember, forget, or update what you know about them, use memory tools. " - "If context.context.resolved_project_id exists, use it as project_id for scoped list calls. " -) - -_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = ( - "You are a strict domain classifier for websocket floating requests. " - "Return ONLY a JSON object with keys: type, id, section. " - "Allowed type values: task, timeline, project, node. " - "Allowed section values: task, timeline, note, or null. " - "Rules: infer from user message intent first; do not blindly trust scope.type. " - "If user asks tasks/timeline/notes for a project, set type=project and section accordingly. " - "If project id is unknown but context.resolved_project_id exists, use it as id. " - "If id is unknown, use null. " - "No markdown, no prose, JSON only." -) - - -def _as_text(content: Any) -> str: - if content is None: - return "" - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for item in content: - if isinstance(item, str): - parts.append(item) - elif isinstance(item, dict): - text = item.get("text") - if isinstance(text, str): - parts.append(text) - return "".join(parts) - return str(content) - - -def _candidate_tokens(message: str) -> list[str]: - tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower()) - return [token for token in tokens if len(token) >= 3] - - -async def _resolve_project_id_from_message(message: str) -> str | None: - """Resolve likely project UUID from user message using client project list.""" - try: - result = await execute_on_client(action="select", table="projects") - except Exception as exc: - logger.warning("deep_agent: project resolve select failed: %s", exc) - return None - - rows = result.get("rows", []) - if not isinstance(rows, list) or not rows: - return None - - tokens = _candidate_tokens(message) - scored: list[tuple[int, dict[str, Any]]] = [] - for row in rows: - if not isinstance(row, dict): - continue - name = str(row.get("name", "")).lower() - score = sum(1 for token in tokens if token in name) - if score > 0: - scored.append((score, row)) - - if not scored: - return None - - scored.sort(key=lambda item: item[0], reverse=True) - top_score = scored[0][0] - top_rows = [row for score, row in scored if score == top_score] - if len(top_rows) != 1: - return None - - project_id = top_rows[0].get("id") - return project_id if isinstance(project_id, str) else None - - -def _needs_project_resolution(message: str) -> bool: - lowered = message.lower() - return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"]) - - -async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]: - prepared = dict(context) - if _needs_project_resolution(message): - resolved_project_id = await _resolve_project_id_from_message(message) - if resolved_project_id: - prepared["resolved_project_id"] = resolved_project_id - logger.info("deep_agent: resolved_project_id=%s", resolved_project_id) - return prepared - - -def _all_tools() -> list[Any]: - return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS] - - -def _trace_id_from_context(context: dict[str, Any]) -> str | None: - debug = context.get("_debug") - if isinstance(debug, dict): - request_id = debug.get("request_id") - if isinstance(request_id, str) and request_id: - return request_id - return None - - -def _context_for_model(context: dict[str, Any]) -> dict[str, Any]: - sanitized = dict(context) - sanitized.pop("_debug", None) - return sanitized - - -_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]") -_TIMELINE_DMY_RE = re.compile(r"(?P\d{2})/(?P\d{2})/(?P\d{4})") - - -def _is_upcoming_timeline_query(message: str) -> bool: - lowered = message.lower() - has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered - has_timeline_topic = any( - token in lowered - for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden") - ) - return has_upcoming and has_timeline_topic - - -def _timeline_date_in_current_month_or_future(dmy: str) -> bool: - match = _TIMELINE_DMY_RE.search(dmy) - if not match: - return True - try: - parsed = date( - int(match.group("y")), - int(match.group("m")), - int(match.group("d")), - ) - except ValueError: - return True - - today = date.today() - return parsed >= today and parsed.year == today.year and parsed.month == today.month - - -def _normalize_tagged_list_lines(text: str, message: str) -> str: - if not text: - return text - - upcoming_timeline_only = _is_upcoming_timeline_query(message) - output_lines: list[str] = [] - - for line in text.splitlines(): - matches = list(_TAG_LINE_RE.finditer(line)) - if not matches: - output_lines.append(line) - continue - - had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)") - if not had_non_tag_text and len(matches) == 1: - tag_text = matches[0].group(0) - if ( - upcoming_timeline_only - and "" in tag_text - and not _timeline_date_in_current_month_or_future(line) - ): - continue - output_lines.append(tag_text) - continue - - for match in matches: - tag_text = match.group(0) - if ( - upcoming_timeline_only - and "" in tag_text - and not _timeline_date_in_current_month_or_future(line) - ): - continue - output_lines.append(tag_text) - - return "\n".join(output_lines) - - -_GENERIC_TAG_RE = re.compile(r"", re.IGNORECASE) -_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]") -_FLOATING_EMPTY_FALLBACK = "No results found." - - -def _strip_floating_markup_fragment(text: str) -> str: - if not text: - return text - cleaned = _GENERIC_TAG_RE.sub("", text) - return _BRACKETED_ID_RE.sub("", cleaned) - - -def _strip_floating_markup(text: str) -> str: - """Ensure floating responses stay plain text with no XML-like tag wrappers.""" - if not text: - return text - - cleaned = _strip_floating_markup_fragment(text) - # Collapse excessive spaces introduced by tag/id removal while preserving lines. - lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()] - return "\n".join(line for line in lines if line) - - -def _fallback_from_raw_floating_text(raw_text: str) -> str: - fallback = _strip_floating_markup_fragment(raw_text or "") - fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip() - return fallback or _FLOATING_EMPTY_FALLBACK - - -class _FloatingStreamSanitizer: - """Streaming sanitizer that removes floating markup without buffering the full answer.""" - - def __init__(self) -> None: - self._pending = "" - - @staticmethod - def _split_safe_boundary(text: str) -> tuple[str, str]: - boundary = len(text) - - last_lt = text.rfind("<") - if last_lt != -1 and ">" not in text[last_lt:]: - boundary = min(boundary, last_lt) - - last_lb = text.rfind("[") - if last_lb != -1 and "]" not in text[last_lb:]: - boundary = min(boundary, last_lb) - - if boundary == len(text): - return text, "" - return text[:boundary], text[boundary:] - - def feed(self, chunk: str) -> str: - combined = f"{self._pending}{chunk}" - safe_text, self._pending = self._split_safe_boundary(combined) - return _strip_floating_markup_fragment(safe_text) - - def finalize(self) -> str: - # Drop dangling unfinished wrappers at the very end. - tail = re.sub(r"<[^>\n]*$", "", self._pending) - tail = re.sub(r"\[[^\]\n]*$", "", tail) - self._pending = "" - return _strip_floating_markup_fragment(tail) - - -def _normalize_memory_label(path_or_label: str) -> str: - value = path_or_label.strip() - if value.startswith("/memories/"): - value = value[len("/memories/"):] - value = value.strip("/") - return value - - -def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]: - @tool - async def memory_list_blocks() -> str: - """List all core memory blocks currently stored for the user.""" - logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id) - async with async_session() as db: - memory = MemoryMiddleware(db) - blocks = await memory.list_core_blocks(user_id) - if not blocks: - return "No memory blocks found." - lines = [f"- {b['label']}: {b['value']}" for b in blocks] - return "Memory blocks:\n" + "\n".join(lines) - - @tool - async def memory_get(path_or_label: str) -> str: - """Get one memory block by label or /memories/