refactor: remove monolith app/, Dockerfile, requirements.txt
All business logic has been extracted into microservices: - services/auth/ (Step 1) - services/ws-gateway/ (Step 2) - services/chat/ (Step 2) - services/batch-agent/ (Step 3) - services/billing/ (Step 4) Shared code lives in shared/. Migrations remain in alembic/. Tests in tests/ will need updating to target individual services.
This commit is contained in:
39
Dockerfile
39
Dockerfile
@@ -1,39 +0,0 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
# Non-root user
|
||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy installed packages from builder
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
# Copy application source
|
||||
COPY app/ app/
|
||||
|
||||
# Copy Alembic migration files
|
||||
COPY alembic/ alembic/
|
||||
COPY alembic.ini .
|
||||
|
||||
# Ensure appuser owns the working directory
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["gunicorn", "app.main:app", \
|
||||
"-k", "uvicorn.workers.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:8000", \
|
||||
"--workers", "4", \
|
||||
"--timeout", "120"]
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||
|
||||
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||
|
||||
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||
|
||||
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_directory(path: str) -> str:
|
||||
"""List files and folders in a local directory on the user's device.
|
||||
|
||||
Returns a formatted listing of entries with name, type (file/directory),
|
||||
and full path.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="list_directory",
|
||||
data={"path": path},
|
||||
)
|
||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||
if not entries:
|
||||
return f"Directory '{path}' is empty or does not exist."
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
entry_type = entry.get("type", "unknown")
|
||||
entry_name = entry.get("name", "")
|
||||
entry_path = entry.get("path", "")
|
||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_file_content(path: str) -> str:
|
||||
"""Read the text content of a local file on the user's device.
|
||||
|
||||
Returns the file content as a string. Large files may be truncated
|
||||
by the Electron client.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="read_file_content",
|
||||
data={"path": path},
|
||||
)
|
||||
content: str = result.get("content", "")
|
||||
if not content:
|
||||
return f"File '{path}' is empty or could not be read."
|
||||
return content
|
||||
|
||||
|
||||
@tool
|
||||
async def get_file_metadata(path: str) -> str:
|
||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||
|
||||
Returns a formatted summary of the file's metadata.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="get_file_metadata",
|
||||
data={"path": path},
|
||||
)
|
||||
size = result.get("size", "unknown")
|
||||
created = result.get("createdAt", "unknown")
|
||||
modified = result.get("modifiedAt", "unknown")
|
||||
extension = result.get("extension", "unknown")
|
||||
name = result.get("name", path)
|
||||
return (
|
||||
f"File: {name}\n"
|
||||
f" Extension: {extension}\n"
|
||||
f" Size: {size} bytes\n"
|
||||
f" Created: {created}\n"
|
||||
f" Modified: {modified}"
|
||||
)
|
||||
|
||||
|
||||
FILESYSTEM_TOOLS: list[Any] = [
|
||||
list_directory,
|
||||
read_file_content,
|
||||
get_file_metadata,
|
||||
]
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.llm import embed
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
NOTE_SYSTEM_PROMPT = (
|
||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||
"and delete Markdown notes in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - content is always Markdown; preserve formatting when updating\n"
|
||||
" - project_id is optional; link a note to a project when mentioned\n"
|
||||
" - When updating, call get_note first if you need to read existing content\n"
|
||||
" before appending or replacing sections\n"
|
||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||
" when the user is working within a specific project\n"
|
||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||
" is already in the note (retrieved via get_note)."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes, optionally scoped to a project by project_id."""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="notes",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No notes found."
|
||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_note(note_id: str) -> str:
|
||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Note {note_id} not found."
|
||||
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_note(
|
||||
title: str,
|
||||
content: str,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new note.
|
||||
title: note heading (required)
|
||||
content: Markdown body text (required)
|
||||
project_id: optional UUID linking this note to a project
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="notes",
|
||||
data={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"projectId": project_id or None,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
# Index the note content in the vector store.
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def update_note(
|
||||
note_id: str,
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
) -> str:
|
||||
"""Update an existing note. Only pass fields that should change.
|
||||
note_id: UUID of the note (required)
|
||||
If you need to preserve existing content, call get_note first.
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if content:
|
||||
updates["content"] = content
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={"id": note_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
# Re-index if content changed.
|
||||
if content:
|
||||
vector = await embed(content)
|
||||
await execute_on_client(
|
||||
action="vector_upsert",
|
||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||
vector=vector,
|
||||
)
|
||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
NOTE_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
create_note,
|
||||
update_note,
|
||||
delete_note,
|
||||
]
|
||||
@@ -1,143 +0,0 @@
|
||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
PROJECT_SYSTEM_PROMPT = (
|
||||
"You are a project management assistant. You help users create, find,\n"
|
||||
"update, and archive projects in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: active, archived\n"
|
||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||
" derive it from context data — do not fabricate content\n"
|
||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||
" user wants a complete cross-client view including archived projects\n"
|
||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||
" list_projects if you only have a project name\n"
|
||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||
" only call delete_project when the user explicitly confirms deletion."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_projects(
|
||||
client_id: str = "",
|
||||
include_archived: int = 0,
|
||||
) -> str:
|
||||
"""List projects, optionally filtered by client_id.
|
||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="projects",
|
||||
filters={
|
||||
"clientId": client_id or None,
|
||||
"includeArchived": bool(include_archived),
|
||||
},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_all_projects() -> str:
|
||||
"""List every project regardless of client or status.
|
||||
Use only when the user wants a complete cross-client overview.
|
||||
"""
|
||||
result = await execute_on_client(action="select", table="projects")
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No projects found."
|
||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_project(project_id: str) -> str:
|
||||
"""Fetch a single project by its UUID."""
|
||||
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Project {project_id} not found."
|
||||
return (
|
||||
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||
f"clientId: {row.get('clientId', 'none')})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_project(
|
||||
name: str,
|
||||
client_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new project.
|
||||
name: human-readable project name (required)
|
||||
client_id: optional UUID of the owning client
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="projects",
|
||||
data={"name": name, "clientId": client_id or None},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
name: str = "",
|
||||
client_id: str = "",
|
||||
status: str = "",
|
||||
ai_summary: str = "",
|
||||
) -> str:
|
||||
"""Update a project. Only pass fields that should change.
|
||||
project_id: UUID of the project (required)
|
||||
status: active | archived
|
||||
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if name:
|
||||
updates["name"] = name
|
||||
if client_id:
|
||||
updates["clientId"] = client_id
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if ai_summary:
|
||||
updates["aiSummary"] = ai_summary
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="projects",
|
||||
data={"id": project_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_project(project_id: str) -> str:
|
||||
"""Permanently delete a project and orphan its tasks.
|
||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||
has explicitly confirmed they want permanent deletion.
|
||||
"""
|
||||
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||
return f"Project {project_id} permanently deleted."
|
||||
|
||||
|
||||
PROJECT_TOOLS: list[Any] = [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
create_project,
|
||||
update_project,
|
||||
delete_project,
|
||||
]
|
||||
@@ -1,238 +0,0 @@
|
||||
"""Task agent — full CRUD for tasks and task comments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
TASK_SYSTEM_PROMPT = (
|
||||
"You are a task management assistant for a project workspace.\n"
|
||||
"You create, update, list, and track tasks and their comments.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: todo, in_progress, done\n"
|
||||
" - priority must be one of: high, medium, low\n"
|
||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||
" - project_id is optional; link to a project when the user mentions one\n"
|
||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||
" did not explicitly request; 0 otherwise\n"
|
||||
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||
" - Always confirm the action in plain, user-friendly language."
|
||||
)
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
search: str = "",
|
||||
order_by: str = "",
|
||||
) -> str:
|
||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_task(
|
||||
title: str,
|
||||
description: str = "",
|
||||
status: str = "todo",
|
||||
priority: str = "medium",
|
||||
assignees: str = "[]",
|
||||
due_date: int = 0,
|
||||
project_id: str = "",
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a new task.
|
||||
title: task title (required)
|
||||
description: optional details
|
||||
status: todo | in_progress | done (default: todo)
|
||||
priority: high | medium | low (default: medium)
|
||||
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||
project_id: optional UUID of the parent project
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="tasks",
|
||||
data={
|
||||
"title": title,
|
||||
"description": description or None,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"assignee": assignees,
|
||||
"dueDate": due_date or None,
|
||||
"projectId": project_id or None,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return (
|
||||
f"Task created: '{row['title']}' "
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignees: str = "",
|
||||
due_date: int = -1,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Update fields on an existing task. Only pass fields you want to change.
|
||||
task_id: the task's UUID (required)
|
||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if description:
|
||||
updates["description"] = description
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if priority:
|
||||
updates["priority"] = priority
|
||||
if assignees:
|
||||
updates["assignee"] = assignees
|
||||
if due_date != -1:
|
||||
updates["dueDate"] = due_date or None
|
||||
if project_id:
|
||||
updates["projectId"] = project_id
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="tasks",
|
||||
data={"id": task_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task(task_id: str) -> str:
|
||||
"""Delete a task permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||
return f"Task {task_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today() -> str:
|
||||
"""List all tasks whose due date falls on today's date."""
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks are due today."
|
||||
lines = [
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_task_comments(task_id: str) -> str:
|
||||
"""List all comments on a task by its UUID."""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="taskComments",
|
||||
filters={"taskId": task_id},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return f"No comments found for task {task_id}."
|
||||
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
"""Add a comment to a task.
|
||||
task_id: UUID of the task to comment on
|
||||
author: name or ID of the comment author
|
||||
content: comment text
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="taskComments",
|
||||
data={"taskId": task_id, "author": author, "content": content},
|
||||
)
|
||||
row = result.get("row", {})
|
||||
row_author = row.get("author", author)
|
||||
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||
row_comment_id = row.get("id", "unknown")
|
||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task_comment(comment_id: str) -> str:
|
||||
"""Delete a task comment by its UUID."""
|
||||
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||
return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
TASK_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
@@ -1,114 +0,0 @@
|
||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
TIMELINE_SYSTEM_PROMPT = (
|
||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
||||
"track progress on a project — they are not calendar events.\n\n"
|
||||
"Rules:\n"
|
||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
||||
" - Listing without a project_id returns all timelines across projects\n"
|
||||
" - Always echo the title and formatted date in your confirmation."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines(project_id: str = "") -> str:
|
||||
"""List timelines. Provide project_id to scope to a specific project."""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="timelines",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timelines found."
|
||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def create_timeline(
|
||||
project_id: str,
|
||||
title: str,
|
||||
date: int,
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a project timeline (milestone).
|
||||
project_id: REQUIRED UUID of the parent project
|
||||
title: descriptive name for the milestone
|
||||
date: Unix timestamp in milliseconds
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="timelines",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_timeline(
|
||||
timeline_id: str,
|
||||
title: str = "",
|
||||
date: int = -1,
|
||||
) -> str:
|
||||
"""Update a timeline. Only pass fields that should change.
|
||||
timeline_id: UUID of the timeline (required)
|
||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if date != -1:
|
||||
updates["date"] = date
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="timelines",
|
||||
data={"id": timeline_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_timeline(timeline_id: str) -> str:
|
||||
"""Delete a timeline permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||
return f"Timeline {timeline_id} deleted."
|
||||
|
||||
|
||||
TIMELINE_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
create_timeline,
|
||||
update_timeline,
|
||||
delete_timeline,
|
||||
]
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Shared FastAPI dependencies.
|
||||
|
||||
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||
(the canonical location per Step 9). This module re-exports them so that all
|
||||
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||
to work without modification.
|
||||
|
||||
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||
instead of reading it from the JWT payload.
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||
|
||||
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""API middleware package.
|
||||
|
||||
Exports the three middleware components introduced in Step 9:
|
||||
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||
- Sanitizer: ``SanitizerMiddleware``
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
|
||||
__all__ = [
|
||||
"get_current_user",
|
||||
"oauth2_scheme",
|
||||
"TierRateLimitMiddleware",
|
||||
"limiter",
|
||||
"SanitizerMiddleware",
|
||||
]
|
||||
@@ -1,80 +0,0 @@
|
||||
"""Auth middleware — JWT validation dependency.
|
||||
|
||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||
without requiring token re-issue.
|
||||
|
||||
Exempt routes (no JWT required):
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.schemas import UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
The JWT is used for identity and expiry only. The tier is fetched live
|
||||
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||
|
||||
Raises HTTP 401 on any invalid or expired token.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
# Live tier lookup — subscription row is the authoritative source.
|
||||
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||
# block local development when no Stripe subscription exists.
|
||||
from app.models import Subscription, User # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
# Fetch name/surname from user row.
|
||||
user_result = await db.execute(
|
||||
select(User.name, User.surname).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
tier=tier,
|
||||
) # type: ignore[arg-type]
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Tier-aware rate limiting middleware.
|
||||
|
||||
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||
|
||||
Limits (requests per minute):
|
||||
- free: 20
|
||||
- pro: 60
|
||||
- power: 120
|
||||
- team: 200
|
||||
|
||||
Exempt paths bypass the limiter entirely:
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
- GET /api/v1/health
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import Request, Response
|
||||
from jose import JWTError, jwt
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
_TIER_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/billing/webhook",
|
||||
"/api/v1/health",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_user_id_from_jwt(request: Request) -> str:
|
||||
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return get_remote_address(request)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
return payload.get("sub") or get_remote_address(request)
|
||||
except JWTError:
|
||||
return get_remote_address(request)
|
||||
|
||||
|
||||
# Exported Limiter instance — available for optional route-level decoration.
|
||||
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||
|
||||
|
||||
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||
|
||||
Each authenticated user gets their own 60-second window sized by tier.
|
||||
Unauthenticated requests pass through (the auth dependency will reject them
|
||||
with 401 before the route handler runs).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
# user_id → list of request timestamps (float, seconds since epoch)
|
||||
self._window: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
if request.url.path in _EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||
tier: str = payload.get("tier", "free")
|
||||
except JWTError:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||
now = time.monotonic()
|
||||
window_start = now - 60.0
|
||||
|
||||
# Slide the window: discard timestamps older than 60 seconds.
|
||||
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||
|
||||
if len(timestamps) >= limit:
|
||||
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"detail": (
|
||||
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||
f"Retry in {retry_after}s."
|
||||
)
|
||||
}
|
||||
),
|
||||
status_code=429,
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
self._window[user_id] = timestamps
|
||||
return await call_next(request)
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Response sanitizer middleware.
|
||||
|
||||
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||
that could reveal server-side prompt IP:
|
||||
- System prompt openers ("You are a/an/the …")
|
||||
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||
- LangChain tool schema fragments (``"type": "function"``)
|
||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||
- Exact-match known prompt fingerprints
|
||||
|
||||
Binary responses (storage blobs, backup data) are never touched — the
|
||||
middleware only activates for paths under /api/v1/chat.
|
||||
|
||||
Any sanitisation event is logged as a WARNING with the request path and the
|
||||
names of the fields that were modified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||
# then compiled regexes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FINGERPRINTS: tuple[str, ...] = (
|
||||
"You are an intent classifier",
|
||||
"Respond with just the agent name",
|
||||
"Summarize these agent results",
|
||||
"Available agents:",
|
||||
"route to:",
|
||||
)
|
||||
|
||||
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||
|
||||
Returns ``(cleaned_text, was_changed)``.
|
||||
"""
|
||||
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||
for fp in _FINGERPRINTS:
|
||||
if fp in text:
|
||||
return "[REDACTED]", True
|
||||
|
||||
changed = False
|
||||
for pattern in _PATTERNS:
|
||||
new_text, n = pattern.subn("[REDACTED]", text)
|
||||
if n:
|
||||
text = new_text
|
||||
changed = True
|
||||
|
||||
return text, changed
|
||||
|
||||
|
||||
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
response: Response = await call_next(request)
|
||||
|
||||
# Only process chat endpoint responses.
|
||||
if not request.url.path.startswith("/api/v1/chat"):
|
||||
return response
|
||||
|
||||
# Read body — collect streaming chunks.
|
||||
body_bytes = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||
try:
|
||||
body = json.loads(body_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
if not isinstance(body, dict):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
# Walk top-level string fields and sanitise.
|
||||
sanitised_fields: list[str] = []
|
||||
for key, value in body.items():
|
||||
if isinstance(value, str):
|
||||
cleaned, changed = _sanitize_text(value)
|
||||
if changed:
|
||||
body[key] = cleaned
|
||||
sanitised_fields.append(key)
|
||||
|
||||
if sanitised_fields:
|
||||
logger.warning(
|
||||
"Sanitizer redacted prompt fragments",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"fields": sanitised_fields,
|
||||
},
|
||||
)
|
||||
|
||||
new_body = json.dumps(body).encode("utf-8")
|
||||
headers = dict(response.headers)
|
||||
headers["content-length"] = str(len(new_body))
|
||||
|
||||
return Response(
|
||||
content=new_body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type="application/json",
|
||||
)
|
||||
@@ -1,406 +0,0 @@
|
||||
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
||||
|
||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||
frames to the functions exported here.
|
||||
|
||||
Journey flow:
|
||||
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
||||
data_types, schedule).
|
||||
2. Server creates an in-memory session, sets up a WS executor so the
|
||||
setup LLM can use file-system tools, does a first directory scrape,
|
||||
and sends back a ``journey_reply`` with the first question.
|
||||
3. FE sends ``journey_message`` frames for each user reply.
|
||||
4. Server appends the user message, calls the LLM (which may read files
|
||||
via tools), and sends back a ``journey_reply``.
|
||||
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
||||
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
||||
and the template. FE stores it locally.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||
from app.core.llm import get_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||
|
||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||
|
||||
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||
|
||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
||||
_MAX_TURNS: int = 15
|
||||
# Max tool-calling steps per LLM invocation.
|
||||
_MAX_TOOL_STEPS: int = 6
|
||||
|
||||
# ── In-memory session store ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class JourneySession:
|
||||
session_id: str
|
||||
user_id: str
|
||||
agent_type: str # "local" | "cloud"
|
||||
directory: str
|
||||
data_types: list[str]
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
system_prompt: str = ""
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||
|
||||
|
||||
# session_id → session
|
||||
_sessions: dict[str, JourneySession] = {}
|
||||
|
||||
|
||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||
s = _sessions.get(session_id)
|
||||
if s is None or s.is_expired():
|
||||
_sessions.pop(session_id, None)
|
||||
return None
|
||||
if s.user_id != user_id:
|
||||
return None
|
||||
return s
|
||||
|
||||
|
||||
# ── System prompt builder ─────────────────────────────────────────────────
|
||||
|
||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||
Your job is to understand exactly what data the user wants to extract from their
|
||||
local directory and produce a detailed prompt_template that a separate AI will use
|
||||
as its instruction set.
|
||||
|
||||
The extraction agent already has this base behaviour built in:
|
||||
- Reads each file using file-system tools.
|
||||
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
||||
- Sets isAiSuggested=1 on every new record.
|
||||
- Only extracts data explicitly present in the files — it never invents information.
|
||||
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
||||
what to look for and how to map it — not on the general extraction mechanics.
|
||||
|
||||
You have access to file-system tools to explore the user's directory:
|
||||
- list_directory: to see folder structure
|
||||
- read_file_content: to peek at file contents
|
||||
- get_file_metadata: to check file info
|
||||
|
||||
The user's configured directory is: {directory}
|
||||
Target data types: {data_types}
|
||||
|
||||
IMPORTANT — project assignment is handled automatically by the main agent runner
|
||||
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
||||
projectId, or how to link records to projects. Never include projectId logic or
|
||||
project creation instructions in the generated prompt_template.
|
||||
|
||||
Start by exploring the directory to understand its structure. Then ask concise,
|
||||
focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||
1. The type and format of the source content (confirmed by your exploration).
|
||||
2. How fields should be mapped (e.g. filename → task title).
|
||||
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||
4. Any special handling, date extraction, or exclusions.
|
||||
|
||||
Once you reach 90% confidence, output the final prompt_template between these exact
|
||||
markers on their own lines:
|
||||
|
||||
{template_start}
|
||||
<the complete extraction prompt here>
|
||||
{template_end}
|
||||
|
||||
The prompt_template must be a self-contained instruction for an AI that reads files
|
||||
and must perform CRUD operations using tools to create records. It should specify:
|
||||
- What entity types to create (tasks, notes, timelines) — never projects.
|
||||
- How to map file content to record fields (camelCase: title, status, priority,
|
||||
dueDate, content, etc.) — never include projectId.
|
||||
- That isAiSuggested must be set to 1 on every new record.
|
||||
- Concrete examples of mappings based on what you discovered in the directory.
|
||||
|
||||
{existing_section}\
|
||||
Keep asking clarifying questions until you are at least 90% confident you have
|
||||
enough information to generate an accurate prompt_template. Once you reach that
|
||||
confidence level, stop asking and produce the final template immediately.
|
||||
Begin by exploring the directory, then ask your first question.\
|
||||
"""
|
||||
|
||||
|
||||
def _build_system_prompt(
|
||||
directory: str,
|
||||
data_types: list[str],
|
||||
existing_template: str | None = None,
|
||||
) -> str:
|
||||
existing_section = (
|
||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||
f"---\n{existing_template}\n---\n"
|
||||
if existing_template
|
||||
else ""
|
||||
)
|
||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||
directory=directory,
|
||||
data_types=", ".join(data_types),
|
||||
template_start=_TEMPLATE_START,
|
||||
template_end=_TEMPLATE_END,
|
||||
existing_section=existing_section,
|
||||
)
|
||||
|
||||
|
||||
# ── Template extraction ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _extract_template(text: str) -> str | None:
|
||||
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||
return None
|
||||
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||
end_idx = text.index(_TEMPLATE_END)
|
||||
return text[start_idx:end_idx].strip() or None
|
||||
|
||||
|
||||
# ── LLM call with tool support ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def _call_llm_with_tools(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tools: list[Any],
|
||||
) -> str:
|
||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||
|
||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||
continue until a final text response is produced.
|
||||
"""
|
||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||
for turn in history:
|
||||
if turn["role"] == "user":
|
||||
messages.append(HumanMessage(content=turn["content"]))
|
||||
else:
|
||||
messages.append(AIMessage(content=turn["content"]))
|
||||
|
||||
llm = get_llm(model=None, temperature=0.4)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
for _ in range(_MAX_TOOL_STEPS):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
return _as_text(response.content)
|
||||
|
||||
for call in response.tool_calls:
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"agent_setup: journey tool_call name=%s args=%s",
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"agent_setup: journey tool_result name=%s output=%s",
|
||||
call_name,
|
||||
str(tool_output)[:800],
|
||||
)
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
# Fallback: exceeded max steps.
|
||||
final = await llm.ainvoke(messages)
|
||||
return _as_text(final.content)
|
||||
|
||||
|
||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||
|
||||
|
||||
async def handle_journey_start(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_start`` WS frame.
|
||||
|
||||
Creates a session, runs the setup LLM with directory exploration,
|
||||
and returns the ``journey_reply`` payload.
|
||||
"""
|
||||
agent_type = frame.get("agent_type", "local")
|
||||
directory = frame.get("directory", "")
|
||||
data_types = frame.get("data_types", [])
|
||||
existing_template = frame.get("existing_template")
|
||||
|
||||
# Use the session_id provided by the FE so the reply matches the
|
||||
# listener key; fall back to a generated one if absent.
|
||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||
|
||||
session = JourneySession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
agent_type=agent_type,
|
||||
directory=directory,
|
||||
data_types=data_types,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
||||
# ws_context executor (already set by the WS handler before calling us).
|
||||
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
||||
# require at least one user/input message to be present.
|
||||
seed_history: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||
]
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=system_prompt,
|
||||
history=seed_history,
|
||||
tools=list(FILESYSTEM_TOOLS),
|
||||
)
|
||||
|
||||
session.history.extend(seed_history)
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
_sessions[session_id] = session
|
||||
|
||||
logger.info(
|
||||
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||
session_id,
|
||||
user_id,
|
||||
directory,
|
||||
)
|
||||
|
||||
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
||||
prompt_template = _extract_template(ai_reply)
|
||||
done = prompt_template is not None
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||
or "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"prompt_template": prompt_template,
|
||||
}
|
||||
|
||||
|
||||
async def handle_journey_message(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_message`` WS frame.
|
||||
|
||||
Appends the user message, calls the LLM, and returns the
|
||||
``journey_reply`` payload.
|
||||
"""
|
||||
session_id = frame.get("session_id", "")
|
||||
message = frame.get("message", "")
|
||||
|
||||
session = get_journey_session(session_id, user_id)
|
||||
if session is None:
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": "Journey session not found or expired. Please start a new setup.",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}
|
||||
|
||||
# Append user turn.
|
||||
session.history.append({"role": "user", "content": message})
|
||||
|
||||
# Call the LLM with tools.
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=list(FILESYSTEM_TOOLS),
|
||||
)
|
||||
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
|
||||
# Check if the LLM produced the final template.
|
||||
prompt_template = _extract_template(ai_reply)
|
||||
done = prompt_template is not None
|
||||
|
||||
# If the LLM didn't produce a template, nudge it once it has asked enough
|
||||
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
|
||||
if not done:
|
||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||
if turns >= _MAX_TURNS:
|
||||
nudge_content = (
|
||||
"[System: You have enough information. Please generate the final "
|
||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||
)
|
||||
session.history.append({"role": "user", "content": nudge_content})
|
||||
|
||||
nudge_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=list(FILESYSTEM_TOOLS),
|
||||
)
|
||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||
|
||||
prompt_template = _extract_template(nudge_reply)
|
||||
if prompt_template is not None:
|
||||
done = True
|
||||
ai_reply = nudge_reply
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||
if _TEMPLATE_START in ai_reply
|
||||
else "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"prompt_template": prompt_template,
|
||||
}
|
||||
@@ -1,222 +0,0 @@
|
||||
"""Agent routes.
|
||||
|
||||
Backend responsibilities are intentionally minimal:
|
||||
GET /agents/catalog — static catalog for UI display
|
||||
POST /agents/can-create — billing eligibility check
|
||||
POST /agents/trigger — trigger a local agent run
|
||||
|
||||
Agent configuration is owned by the Electron app and is not persisted
|
||||
in backend agent-config tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import get_session
|
||||
from app.models import AgentRunLog, LocalAgentConfig
|
||||
from app.schemas import (
|
||||
AgentCatalogItem,
|
||||
AgentCreationCheckRequest,
|
||||
AgentCreationCheckResponse,
|
||||
AgentRunLogResponse,
|
||||
AgentTriggerRequest,
|
||||
UserProfile,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
|
||||
# ── Datetime helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _dt_ms(dt: datetime) -> int:
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
def _to_data_types(values: list[str]) -> list[str]:
|
||||
normalize = {
|
||||
"task": "tasks", "tasks": "tasks",
|
||||
"note": "notes", "notes": "notes",
|
||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||
"project": "projects", "projects": "projects",
|
||||
}
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for v in values:
|
||||
mapped = normalize.get(v)
|
||||
if mapped and mapped not in seen:
|
||||
seen.add(mapped)
|
||||
result.append(mapped)
|
||||
return result
|
||||
|
||||
|
||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||
return AgentRunLogResponse(
|
||||
id=log.id,
|
||||
agent_id=log.agent_id,
|
||||
agent_type=log.agent_type, # type: ignore[arg-type]
|
||||
status=log.status, # type: ignore[arg-type]
|
||||
items_processed=log.items_processed,
|
||||
items_created=log.items_created,
|
||||
errors=log.errors or [],
|
||||
started_at=_dt_ms(log.started_at),
|
||||
completed_at=_dt_ms_opt(log.completed_at),
|
||||
)
|
||||
|
||||
|
||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||
if limit != -1 and current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||
)
|
||||
return limit
|
||||
|
||||
|
||||
async def _enforce_run_frequency(
|
||||
tier: str,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||
if limit == -1:
|
||||
return # unlimited
|
||||
|
||||
today_start = datetime.now(timezone.utc).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
result = await db.execute(
|
||||
select(func.count(AgentRunLog.id)).where(
|
||||
AgentRunLog.user_id == user_id,
|
||||
AgentRunLog.started_at >= today_start,
|
||||
)
|
||||
)
|
||||
runs_today: int = result.scalar_one()
|
||||
|
||||
if runs_today >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||
)
|
||||
|
||||
|
||||
# ── Catalog ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog", response_model=list[AgentCatalogItem])
|
||||
async def get_agent_catalog(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[AgentCatalogItem]:
|
||||
"""Return the static list of available agent types and their descriptions."""
|
||||
return [
|
||||
AgentCatalogItem(
|
||||
type="local_directory",
|
||||
name="Local Directory Monitor",
|
||||
description="Watches local directories, extracts data from files using AI",
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="gmail",
|
||||
name="Gmail Connector",
|
||||
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="teams",
|
||||
name="Microsoft Teams Connector",
|
||||
description="Monitors Teams messages, extracts action items",
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="outlook",
|
||||
name="Outlook Connector",
|
||||
description="Scans Outlook inbox, extracts tasks/notes",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||
async def can_create_agent(
|
||||
body: AgentCreationCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> AgentCreationCheckResponse:
|
||||
"""Check if the user can create one more agent based on billing tier.
|
||||
|
||||
Since configuration is client-owned, the Electron app sends its current
|
||||
active agent count and the backend applies tier limits.
|
||||
"""
|
||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||
allowed = limit == -1 or body.active_agents < limit
|
||||
return AgentCreationCheckResponse(
|
||||
allowed=allowed,
|
||||
tier=current_user.tier,
|
||||
active_agents=body.active_agents,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_agent_run(
|
||||
body: AgentTriggerRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AgentRunLogResponse:
|
||||
"""Trigger a local agent run using client-provided configuration."""
|
||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||
|
||||
config = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
device_id=body.device_id,
|
||||
name="Local Directory Monitor",
|
||||
directory_paths=[body.directory],
|
||||
data_types=_to_data_types(body.what_to_extract),
|
||||
prompt_template=body.custom_agent_prompt,
|
||||
file_extensions=[],
|
||||
schedule_cron=body.batch_interval,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||
stable_agent_id = body.agent_id or config.id
|
||||
|
||||
if is_agent_running(stable_agent_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||
)
|
||||
|
||||
run_log = AgentRunLog(
|
||||
agent_id=stable_agent_id,
|
||||
agent_type="local",
|
||||
user_id=current_user.id,
|
||||
status="running",
|
||||
)
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
run_context = {
|
||||
"type": "agent_batch",
|
||||
"run_id": run_log.id,
|
||||
"agent_id": stable_agent_id,
|
||||
}
|
||||
|
||||
asyncio.create_task(
|
||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
@@ -1,235 +0,0 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.models import RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
class _LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class _RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
name=body.name,
|
||||
surname=body.surname,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
|
||||
# Rotate: delete old token, issue new one.
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
class _UpdateProfileRequest(BaseModel):
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||
"""Return the profile for the authenticated user."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.put("/me", response_model=UserProfile)
|
||||
async def update_profile(
|
||||
body: _UpdateProfileRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's name and surname."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if body.name is not None:
|
||||
user.name = body.name
|
||||
if body.surname is not None:
|
||||
user.surname = body.surname
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return UserProfile(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
surname=user.surname,
|
||||
tier=current_user.tier,
|
||||
)
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||
|
||||
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||
PostgreSQL ``backup_metadata`` table.
|
||||
|
||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||
treating "history" as a ``{backup_id}`` path parameter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from email.utils import parsedate_to_datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.db import get_session
|
||||
from app.models import BackupMetadata as BackupMetadataModel
|
||||
from app.schemas import BackupMetadata, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
|
||||
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||
|
||||
_blob_store = BlobStore()
|
||||
|
||||
|
||||
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||
"""Return total backup bytes stored by *user_id*."""
|
||||
result = await db.execute(
|
||||
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||
BackupMetadataModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
async def _check_backup_quota(
|
||||
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||
) -> None:
|
||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||
current = await _current_backup_bytes(user.id, db)
|
||||
tier_manager.enforce_backup_quota(
|
||||
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def upload_backup(
|
||||
request: Request,
|
||||
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Upload an E2E-encrypted backup blob.
|
||||
|
||||
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
||||
"""
|
||||
blob = await request.body()
|
||||
reject_if_tampered(blob, x_backup_checksum)
|
||||
await _check_backup_quota(current_user, len(blob), db)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||
)
|
||||
|
||||
row = BackupMetadataModel(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
s3_key=s3_key,
|
||||
version=x_backup_version,
|
||||
timestamp=x_backup_timestamp,
|
||||
checksum=x_backup_checksum,
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
db.add(row)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/history", response_model=list[BackupMetadata])
|
||||
async def backup_history(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[BackupMetadata]:
|
||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||
result = await db.execute(
|
||||
select(BackupMetadataModel)
|
||||
.where(BackupMetadataModel.user_id == current_user.id)
|
||||
.order_by(BackupMetadataModel.timestamp.desc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
BackupMetadata(
|
||||
version=r.version,
|
||||
timestamp=r.timestamp,
|
||||
checksum=r.checksum,
|
||||
chunk_count=1,
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def download_backup(
|
||||
request: Request,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||
result = await db.execute(
|
||||
select(BackupMetadataModel)
|
||||
.where(BackupMetadataModel.user_id == current_user.id)
|
||||
.order_by(BackupMetadataModel.timestamp.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest = result.scalar_one_or_none()
|
||||
if latest is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||
|
||||
ims_header = request.headers.get("If-Modified-Since")
|
||||
if ims_header:
|
||||
try:
|
||||
ims_dt = parsedate_to_datetime(ims_header)
|
||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||
if latest.timestamp <= ims_ms:
|
||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||
except Exception:
|
||||
pass # malformed header — ignore and serve the blob
|
||||
|
||||
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||
return Response(
|
||||
content=blob,
|
||||
media_type="application/octet-stream",
|
||||
headers={
|
||||
"X-Backup-Version": str(latest.version),
|
||||
"X-Backup-Timestamp": str(latest.timestamp),
|
||||
"X-Checksum": latest.checksum,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{backup_id}", response_model=dict)
|
||||
async def delete_backup(
|
||||
backup_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a specific backup by ID."""
|
||||
result = await db.execute(
|
||||
select(BackupMetadataModel).where(
|
||||
BackupMetadataModel.id == backup_id,
|
||||
BackupMetadataModel.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
target = result.scalar_one_or_none()
|
||||
if target is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||
|
||||
await _blob_store.delete(current_user.id, target.s3_key)
|
||||
await db.delete(target)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
@@ -1,85 +0,0 @@
|
||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||
|
||||
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||
and delegates everything else to the service singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.db import get_session
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
|
||||
# ── Request bodies ─────────────────────────────────────────────────────
|
||||
|
||||
class _CheckoutRequest(BaseModel):
|
||||
tier: BillingTier
|
||||
|
||||
|
||||
# ── Routes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/checkout", response_model=dict)
|
||||
async def create_checkout(
|
||||
body: _CheckoutRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, str]:
|
||||
"""Create a Stripe checkout session for a tier upgrade.
|
||||
|
||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||
"""
|
||||
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||
return {"checkout_url": url}
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=dict)
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
No JWT auth — authenticated via Stripe signature verification instead.
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=dict)
|
||||
async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
"status": "free",
|
||||
"stripe_subscription_id": None,
|
||||
"current_period_end": None,
|
||||
}
|
||||
return sub
|
||||
|
||||
|
||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Chat routes: POST /chat (REST fallback).
|
||||
|
||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.deep_agent import run_home
|
||||
from app.schemas import ChatRequest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def chat(
|
||||
body: ChatRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> JSONResponse:
|
||||
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||
response = await run_home(
|
||||
user_id=current_user.id,
|
||||
message=body.message,
|
||||
context=body.context.model_dump(),
|
||||
)
|
||||
return JSONResponse(content={"response": response})
|
||||
@@ -1,417 +0,0 @@
|
||||
"""Device WebSocket endpoint.
|
||||
|
||||
Persistent connection from Electron devices to the backend.
|
||||
|
||||
WS /api/v1/ws/device?token=<jwt>
|
||||
|
||||
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||
available during the WebSocket handshake).
|
||||
|
||||
Protocol:
|
||||
1. Client connects → JWT validated → connection accepted.
|
||||
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
||||
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||
4. Session enters message dispatch loop + heartbeat.
|
||||
|
||||
Incoming frame dispatch:
|
||||
- ``tool_result`` → resolves a pending tool-call Future.
|
||||
- ``journey_start`` → starts a guided setup journey session.
|
||||
- ``journey_message`` → continues a journey conversation.
|
||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||
- unknown types → logged, ignored.
|
||||
|
||||
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||
|
||||
On disconnect:
|
||||
- Unregisters from DeviceConnectionManager.
|
||||
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||
with message "device disconnected".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import update
|
||||
|
||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.output_formatter import StreamFormatter
|
||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
from app.schemas import WsFrameType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
|
||||
@router.websocket("/device")
|
||||
async def device_ws(websocket: WebSocket) -> None:
|
||||
"""Persistent WebSocket endpoint for Electron device connections.
|
||||
|
||||
Authentication is via ``?token=<jwt>`` query parameter.
|
||||
"""
|
||||
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||
token = websocket.query_params.get("token", "")
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
if not user_id:
|
||||
raise JWTError("missing sub")
|
||||
except JWTError:
|
||||
await websocket.close(code=1008) # Policy Violation
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||
try:
|
||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
try:
|
||||
hello = json.loads(raw)
|
||||
if hello.get("type") != WsFrameType.device_hello:
|
||||
raise ValueError("expected device_hello as first frame")
|
||||
device_id: str = hello["device_id"]
|
||||
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# ── 3. Register connection ────────────────────────────────────────
|
||||
device_manager.register(user_id, device_id, websocket)
|
||||
logger.info(
|
||||
"device_ws: connected user=%s device=%s agents=%s",
|
||||
user_id,
|
||||
device_id,
|
||||
agent_ids,
|
||||
)
|
||||
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_message_loop(websocket, user_id),
|
||||
_heartbeat_loop(websocket),
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||
finally:
|
||||
device_manager.unregister(user_id)
|
||||
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||
await _mark_runs_disconnected(user_id)
|
||||
|
||||
|
||||
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||
|
||||
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||
async for raw in websocket.iter_text():
|
||||
try:
|
||||
frame: dict = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||
continue
|
||||
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == WsFrameType.tool_result:
|
||||
call_id = frame.get("id")
|
||||
if call_id:
|
||||
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||
else:
|
||||
logger.warning(
|
||||
"device_ws: tool_result missing id from user=%s", user_id
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.home_request:
|
||||
asyncio.create_task(
|
||||
_handle_home_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.floating_request:
|
||||
asyncio.create_task(
|
||||
_handle_floating_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_start:
|
||||
asyncio.create_task(
|
||||
_handle_journey_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_message:
|
||||
asyncio.create_task(
|
||||
_handle_journey_message(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||
)
|
||||
|
||||
|
||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||
|
||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||
async def _executor(payload: dict) -> dict:
|
||||
payload["type"] = WsFrameType.tool_call
|
||||
await websocket.send_text(json.dumps(payload))
|
||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||
return await future
|
||||
return _executor
|
||||
|
||||
|
||||
async def _handle_home_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
logger.info(
|
||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
# ── Memory: enrich context before LLM call ────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_home_stream(user_id, message, context)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# Collect text chunks to build the full response for episode storage
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: home_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# ── Memory: store episode after response ──────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_floating_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope: dict = frame.get("scope", {})
|
||||
logger.info(
|
||||
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
json.dumps(scope, ensure_ascii=True)[:200],
|
||||
message[:200],
|
||||
)
|
||||
|
||||
# ── Memory: enrich context before LLM call ────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"scope": scope,
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_floating_stream(user_id, message, context)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: floating_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# ── Memory: store episode after response ──────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_journey_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
reply = await handle_journey_start(user_id, frame)
|
||||
await websocket.send_text(json.dumps(reply))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "journey_reply",
|
||||
"session_id": frame.get("session_id", ""),
|
||||
"message": f"Failed to start journey: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}))
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
async def _handle_journey_message(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a journey_message frame — continues the journey conversation."""
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
reply = await handle_journey_message(user_id, frame)
|
||||
await websocket.send_text(json.dumps(reply))
|
||||
except Exception as exc:
|
||||
session_id = frame.get("session_id", "")
|
||||
logger.error(
|
||||
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||
user_id, session_id, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": f"Journey error: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}))
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||
while True:
|
||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||
|
||||
|
||||
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||
|
||||
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await db.execute(
|
||||
update(AgentRunLog)
|
||||
.where(
|
||||
AgentRunLog.user_id == user_id,
|
||||
AgentRunLog.status == "running",
|
||||
)
|
||||
.values(
|
||||
status="error",
|
||||
errors=["device disconnected"],
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
@@ -1,148 +0,0 @@
|
||||
"""Plugins routes: browse and install plugins from the marketplace.
|
||||
|
||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.db import get_session
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.marketplace.revenue_share import revenue_share
|
||||
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||
|
||||
|
||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||
|
||||
def _require_plugin_tier(user: UserProfile) -> None:
|
||||
"""Raise HTTP 403 for users below Power tier."""
|
||||
if user.tier not in ("power", "team"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Plugin marketplace requires Power tier or above",
|
||||
)
|
||||
|
||||
|
||||
# ── Local detail schema ────────────────────────────────────────────────
|
||||
|
||||
class _PluginDetail(BaseModel):
|
||||
plugin: PluginManifest
|
||||
install_count: int
|
||||
ratings: list[Any]
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=PluginListResponse)
|
||||
async def list_plugins(
|
||||
category: str | None = Query(default=None),
|
||||
q: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> PluginListResponse:
|
||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||
async def get_plugin(
|
||||
plugin_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> _PluginDetail:
|
||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
entry = await registry.get_plugin(db, plugin_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
|
||||
# Fetch review ratings for this plugin
|
||||
review_result = await db.execute(
|
||||
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||
)
|
||||
reviews = review_result.scalars().all()
|
||||
ratings = [
|
||||
{
|
||||
"reviewer_id": r.reviewer_id,
|
||||
"decision": r.decision,
|
||||
"notes": r.notes,
|
||||
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||
}
|
||||
for r in reviews
|
||||
]
|
||||
|
||||
return _PluginDetail(
|
||||
plugin=entry["manifest"],
|
||||
install_count=entry["install_count"],
|
||||
ratings=ratings,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/install", response_model=dict)
|
||||
async def install_plugin(
|
||||
plugin_id: str,
|
||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||
|
||||
Requires Power tier or above.
|
||||
"""
|
||||
_require_plugin_tier(current_user)
|
||||
entry = await registry.get_plugin(db, plugin_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
|
||||
# Record the installation in plugin_installations
|
||||
installation = PluginInstallation(
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
db.add(installation)
|
||||
await db.flush()
|
||||
|
||||
await revenue_share.record_install(
|
||||
db,
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user.id,
|
||||
amount_cents=entry["manifest"].price_cents,
|
||||
)
|
||||
|
||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||
return {"ok": True, "download_url": download_url}
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}/install", response_model=dict)
|
||||
async def uninstall_plugin(
|
||||
plugin_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Unregister a plugin installation."""
|
||||
result = await db.execute(
|
||||
select(PluginInstallation).where(
|
||||
PluginInstallation.plugin_id == plugin_id,
|
||||
PluginInstallation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
installation = result.scalar_one_or_none()
|
||||
if installation is not None:
|
||||
await db.delete(installation)
|
||||
await db.commit()
|
||||
await registry.record_uninstall(db, plugin_id)
|
||||
return {"ok": True}
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||
|
||||
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||
PostgreSQL ``storage_records`` table.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.db import get_session
|
||||
from app.models import StorageRecord
|
||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
|
||||
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||
|
||||
_blob_store = BlobStore()
|
||||
|
||||
|
||||
# ── Local response schemas ─────────────────────────────────────────────
|
||||
|
||||
class _CreateResponse(BaseModel):
|
||||
id: str
|
||||
created_at: int
|
||||
|
||||
|
||||
class _RecordMeta(BaseModel):
|
||||
id: str
|
||||
table: str
|
||||
checksum: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||
"""Return total bytes stored by *user_id*."""
|
||||
result = await db.execute(
|
||||
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||
StorageRecord.user_id == user_id
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||
current = await _current_usage_bytes(user.id, db)
|
||||
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||
|
||||
|
||||
async def _get_record_for_user(
|
||||
record_id: str, user_id: str, db: AsyncSession
|
||||
) -> StorageRecord:
|
||||
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||
to prevent user enumeration attacks."""
|
||||
result = await db.execute(
|
||||
select(StorageRecord).where(
|
||||
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if record is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||
return record
|
||||
|
||||
|
||||
# ── Routes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_record(
|
||||
body: StorageRecordCreate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> _CreateResponse:
|
||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
await _check_quota(current_user, len(body.blob), db)
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||
)
|
||||
|
||||
record = StorageRecord(
|
||||
id=record_id,
|
||||
user_id=current_user.id,
|
||||
table_name=body.table,
|
||||
s3_key=s3_key,
|
||||
checksum=body.checksum,
|
||||
size_bytes=len(body.blob),
|
||||
)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
|
||||
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||
|
||||
|
||||
@router.get("/records", response_model=list[_RecordMeta])
|
||||
async def list_records(
|
||||
table: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[_RecordMeta]:
|
||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||
if table is not None:
|
||||
query = query.where(StorageRecord.table_name == table)
|
||||
query = query.offset((page - 1) * limit).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
return [
|
||||
_RecordMeta(
|
||||
id=r.id,
|
||||
table=r.table_name,
|
||||
checksum=r.checksum,
|
||||
created_at=int(r.created_at.timestamp() * 1000),
|
||||
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.get("/records/{record_id}")
|
||||
async def download_record(
|
||||
record_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||
return Response(
|
||||
content=blob,
|
||||
media_type="application/octet-stream",
|
||||
headers={"X-Checksum": record.checksum},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/records/{record_id}", response_model=dict)
|
||||
async def update_record(
|
||||
record_id: str,
|
||||
body: StorageRecordUpdate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
|
||||
delta = len(body.blob) - record.size_bytes
|
||||
if delta > 0:
|
||||
await _check_quota(current_user, delta, db)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||
)
|
||||
|
||||
record.s3_key = s3_key
|
||||
record.checksum = body.checksum
|
||||
record.size_bytes = len(body.blob)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.delete("/records/{record_id}", response_model=dict)
|
||||
async def delete_record(
|
||||
record_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a record and its S3 blob."""
|
||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||
await _blob_store.delete(current_user.id, record.s3_key)
|
||||
await db.delete(record)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
@@ -1,79 +0,0 @@
|
||||
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.llm import embed
|
||||
from app.schemas import (
|
||||
UserProfile,
|
||||
VectorSearchRequest,
|
||||
VectorSearchResponse,
|
||||
VectorUpsertRequest,
|
||||
)
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
from app.storage.vector_store import VectorStore
|
||||
|
||||
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||
|
||||
_vector_store = VectorStore()
|
||||
|
||||
|
||||
class _VectorDeleteRequest(BaseModel):
|
||||
ids: list[str]
|
||||
|
||||
|
||||
class _EmbedRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class _EmbedResponse(BaseModel):
|
||||
vector: list[float]
|
||||
|
||||
|
||||
@router.post("/vectors/upsert", response_model=dict)
|
||||
async def upsert_vectors(
|
||||
body: VectorUpsertRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, int]:
|
||||
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
||||
for item in body.vectors:
|
||||
reject_if_tampered(item.blob, item.checksum)
|
||||
await _vector_store.upsert(current_user.id, body.vectors)
|
||||
return {"upserted": len(body.vectors)}
|
||||
|
||||
|
||||
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
||||
async def search_vectors(
|
||||
body: VectorSearchRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> VectorSearchResponse:
|
||||
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
||||
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
||||
return VectorSearchResponse(results=results)
|
||||
|
||||
|
||||
@router.delete("/vectors", response_model=dict)
|
||||
async def delete_vectors(
|
||||
body: _VectorDeleteRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||
await _vector_store.delete(current_user.id, body.ids)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
||||
async def embed_text(
|
||||
body: _EmbedRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _EmbedResponse:
|
||||
"""Generate a 1536-dim embedding vector for the given text.
|
||||
|
||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
||||
"""
|
||||
vector = await embed(body.text)
|
||||
return _EmbedResponse(vector=vector)
|
||||
@@ -1,4 +0,0 @@
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.billing.tier_manager import tier_manager
|
||||
|
||||
__all__ = ["stripe_service", "tier_manager"]
|
||||
@@ -1,256 +0,0 @@
|
||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||
|
||||
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||
configured, enabling local development without live credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||
TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly",
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||
|
||||
# ── Internal helpers ────────────────────────────────────────────────
|
||||
|
||||
def _configured(self) -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
def _client(self) -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: str,
|
||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||
) -> str:
|
||||
"""Create a Stripe checkout session and return the URL.
|
||||
|
||||
Returns a stub URL when Stripe is not configured.
|
||||
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||
"""
|
||||
if tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
price_id = TIER_PRICE_IDS.get(tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {tier}",
|
||||
)
|
||||
|
||||
if not self._configured():
|
||||
return "https://stripe.com/stub-checkout"
|
||||
|
||||
s = self._client()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={"user_id": user_id, "tier": tier},
|
||||
)
|
||||
return session.url
|
||||
|
||||
async def handle_webhook(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Process a Stripe webhook event.
|
||||
|
||||
Verifies the signature, then dispatches on event type.
|
||||
Raises ``HTTP 400`` on signature mismatch.
|
||||
No-ops when Stripe is not configured.
|
||||
"""
|
||||
if not self._configured():
|
||||
return
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if user_id:
|
||||
await self._upsert_subscription(
|
||||
db, user_id, sub_id, tier, "active", period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
sub_id = data.get("id")
|
||||
new_status = data.get("status", "active")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status=new_status, current_period_end=period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
sub_id = data.get("id")
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, tier="free", status="canceled"
|
||||
)
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
sub_id = data.get("subscription")
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status="past_due"
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def get_subscription(
|
||||
self, user_id: str, db: AsyncSession
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return None
|
||||
return {
|
||||
"tier": sub.tier,
|
||||
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||
"status": sub.status,
|
||||
"current_period_end": (
|
||||
int(sub.current_period_end.timestamp() * 1000)
|
||||
if sub.current_period_end
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||
|
||||
Raises ``HTTP 404`` when no active subscription exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None or not sub.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if self._configured():
|
||||
s = self._client()
|
||||
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||
|
||||
sub.tier = "free"
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
stripe_subscription_id: str | None,
|
||||
tier: str,
|
||||
sub_status: str,
|
||||
current_period_end: datetime | None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
sub = Subscription(user_id=user_id)
|
||||
db.add(sub)
|
||||
sub.stripe_subscription_id = stripe_subscription_id
|
||||
sub.tier = tier
|
||||
sub.status = sub_status
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
async def _update_subscription_by_stripe_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
stripe_subscription_id: str,
|
||||
*,
|
||||
tier: str | None = None,
|
||||
status: str | None = None,
|
||||
current_period_end: datetime | None = None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||
)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
if tier is not None:
|
||||
sub.tier = tier
|
||||
if status is not None:
|
||||
sub.status = status
|
||||
if current_period_end is not None:
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
stripe_service = StripeService()
|
||||
@@ -1,195 +0,0 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||
FEATURES: dict[str, dict[str, Any]] = {
|
||||
"free": {
|
||||
"agents": 3,
|
||||
"batch_active": 2,
|
||||
"batch_runs_per_day": 5,
|
||||
"cloud_storage_gb": 0,
|
||||
"backup_gb": 0,
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"plugin_marketplace": False,
|
||||
"sso": False,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
"batch_active": 10,
|
||||
"batch_runs_per_day": 50,
|
||||
"cloud_storage_gb": 5,
|
||||
"backup_gb": 5,
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"plugin_marketplace": False,
|
||||
"sso": False,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
"batch_active": -1, # unlimited
|
||||
"batch_runs_per_day": -1, # unlimited
|
||||
"cloud_storage_gb": 25,
|
||||
"backup_gb": 25,
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"plugin_marketplace": True,
|
||||
"sso": False,
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
"batch_active": -1,
|
||||
"batch_runs_per_day": -1, # unlimited
|
||||
"cloud_storage_gb": -1, # unlimited
|
||||
"backup_gb": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"plugin_marketplace": True,
|
||||
"sso": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Requests-per-minute limit per tier.
|
||||
RATE_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id`` from the DB.
|
||||
|
||||
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||
when no subscription row exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str | None = result.scalar_one_or_none()
|
||||
if tier is None or tier not in FEATURES:
|
||||
return "power" if settings.ENV == "dev" else "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||
if not self.check_feature(tier, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
else f"Feature '{feature}' is not available on your current tier."
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
"""Return the requests-per-minute limit for ``tier``."""
|
||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||
|
||||
# ── Storage quota ────────────────────────────────────────────────────
|
||||
|
||||
def enforce_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||
|
||||
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||
"""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
if current_bytes + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def enforce_backup_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup is not available on the '{tier}' tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
if current_bytes + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False
|
||||
if limit_gb == -1:
|
||||
return True
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
@@ -1,62 +0,0 @@
|
||||
from typing import Literal
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
||||
JWT_SECRET: str = "change-me-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
STRIPE_SECRET_KEY: str = ""
|
||||
STRIPE_WEBHOOK_SECRET: str = ""
|
||||
|
||||
S3_BUCKET: str = ""
|
||||
S3_REGION: str = "us-east-1"
|
||||
S3_ENDPOINT_URL: str = ""
|
||||
AWS_ACCESS_KEY_ID: str = ""
|
||||
AWS_SECRET_ACCESS_KEY: str = ""
|
||||
|
||||
PINECONE_API_KEY: str = ""
|
||||
PINECONE_INDEX: str = "adiuva"
|
||||
QDRANT_URL: str = ""
|
||||
QDRANT_API_KEY: str = ""
|
||||
|
||||
OPENAI_API_KEY: str = ""
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
GOOGLE_API_KEY: str = ""
|
||||
CEREBRAS_API_KEY: str = ""
|
||||
GITHUB_TOKEN: str = ""
|
||||
|
||||
LLM_MODEL: str = "gpt-4o"
|
||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||
|
||||
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||
GMAIL_CLIENT_ID: str = ""
|
||||
GMAIL_CLIENT_SECRET: str = ""
|
||||
MS_CLIENT_ID: str = ""
|
||||
MS_CLIENT_SECRET: str = ""
|
||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||
MS_TENANT_ID: str = "common"
|
||||
|
||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||
OAUTH_ENCRYPTION_KEY: str = ""
|
||||
|
||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Common base for non-chat agents still using the old base contract."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str = "",
|
||||
shared_memory: dict[str, Any] | None = None,
|
||||
vector_store_context: list[str] | None = None,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||
self.vector_store_context: list[str] = vector_store_context or []
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def skills(self) -> list[str]:
|
||||
return []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,846 +0,0 @@
|
||||
"""Single-agent runners for home and floating chat contexts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import date
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.core.llm import get_llm
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||
from app.db import async_session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||
|
||||
_HOME_SINGLE_AGENT_SYSTEM = (
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
||||
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
||||
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
||||
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
||||
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||
)
|
||||
|
||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||
)
|
||||
|
||||
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
||||
"You are a strict domain classifier for websocket floating requests. "
|
||||
"Return ONLY a JSON object with keys: type, id, section. "
|
||||
"Allowed type values: task, timeline, project, node. "
|
||||
"Allowed section values: task, timeline, note, or null. "
|
||||
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
||||
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
||||
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
||||
"If id is unknown, use null. "
|
||||
"No markdown, no prose, JSON only."
|
||||
)
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
def _candidate_tokens(message: str) -> list[str]:
|
||||
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
||||
return [token for token in tokens if len(token) >= 3]
|
||||
|
||||
|
||||
async def _resolve_project_id_from_message(message: str) -> str | None:
|
||||
"""Resolve likely project UUID from user message using client project list."""
|
||||
try:
|
||||
result = await execute_on_client(action="select", table="projects")
|
||||
except Exception as exc:
|
||||
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
||||
return None
|
||||
|
||||
rows = result.get("rows", [])
|
||||
if not isinstance(rows, list) or not rows:
|
||||
return None
|
||||
|
||||
tokens = _candidate_tokens(message)
|
||||
scored: list[tuple[int, dict[str, Any]]] = []
|
||||
for row in rows:
|
||||
if not isinstance(row, dict):
|
||||
continue
|
||||
name = str(row.get("name", "")).lower()
|
||||
score = sum(1 for token in tokens if token in name)
|
||||
if score > 0:
|
||||
scored.append((score, row))
|
||||
|
||||
if not scored:
|
||||
return None
|
||||
|
||||
scored.sort(key=lambda item: item[0], reverse=True)
|
||||
top_score = scored[0][0]
|
||||
top_rows = [row for score, row in scored if score == top_score]
|
||||
if len(top_rows) != 1:
|
||||
return None
|
||||
|
||||
project_id = top_rows[0].get("id")
|
||||
return project_id if isinstance(project_id, str) else None
|
||||
|
||||
|
||||
def _needs_project_resolution(message: str) -> bool:
|
||||
lowered = message.lower()
|
||||
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
||||
|
||||
|
||||
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
||||
prepared = dict(context)
|
||||
if _needs_project_resolution(message):
|
||||
resolved_project_id = await _resolve_project_id_from_message(message)
|
||||
if resolved_project_id:
|
||||
prepared["resolved_project_id"] = resolved_project_id
|
||||
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
|
||||
return prepared
|
||||
|
||||
|
||||
def _all_tools() -> list[Any]:
|
||||
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
||||
|
||||
|
||||
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||
debug = context.get("_debug")
|
||||
if isinstance(debug, dict):
|
||||
request_id = debug.get("request_id")
|
||||
if isinstance(request_id, str) and request_id:
|
||||
return request_id
|
||||
return None
|
||||
|
||||
|
||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||
sanitized = dict(context)
|
||||
sanitized.pop("_debug", None)
|
||||
return sanitized
|
||||
|
||||
|
||||
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
||||
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
||||
|
||||
|
||||
def _is_upcoming_timeline_query(message: str) -> bool:
|
||||
lowered = message.lower()
|
||||
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
||||
has_timeline_topic = any(
|
||||
token in lowered
|
||||
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
||||
)
|
||||
return has_upcoming and has_timeline_topic
|
||||
|
||||
|
||||
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
||||
match = _TIMELINE_DMY_RE.search(dmy)
|
||||
if not match:
|
||||
return True
|
||||
try:
|
||||
parsed = date(
|
||||
int(match.group("y")),
|
||||
int(match.group("m")),
|
||||
int(match.group("d")),
|
||||
)
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
today = date.today()
|
||||
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
||||
|
||||
|
||||
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
|
||||
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
||||
output_lines: list[str] = []
|
||||
|
||||
for line in text.splitlines():
|
||||
matches = list(_TAG_LINE_RE.finditer(line))
|
||||
if not matches:
|
||||
output_lines.append(line)
|
||||
continue
|
||||
|
||||
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
||||
if not had_non_tag_text and len(matches) == 1:
|
||||
tag_text = matches[0].group(0)
|
||||
if (
|
||||
upcoming_timeline_only
|
||||
and "<timeline>" in tag_text
|
||||
and not _timeline_date_in_current_month_or_future(line)
|
||||
):
|
||||
continue
|
||||
output_lines.append(tag_text)
|
||||
continue
|
||||
|
||||
for match in matches:
|
||||
tag_text = match.group(0)
|
||||
if (
|
||||
upcoming_timeline_only
|
||||
and "<timeline>" in tag_text
|
||||
and not _timeline_date_in_current_month_or_future(line)
|
||||
):
|
||||
continue
|
||||
output_lines.append(tag_text)
|
||||
|
||||
return "\n".join(output_lines)
|
||||
|
||||
|
||||
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
||||
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
||||
_FLOATING_EMPTY_FALLBACK = "No results found."
|
||||
|
||||
|
||||
def _strip_floating_markup_fragment(text: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
cleaned = _GENERIC_TAG_RE.sub("", text)
|
||||
return _BRACKETED_ID_RE.sub("", cleaned)
|
||||
|
||||
|
||||
def _strip_floating_markup(text: str) -> str:
|
||||
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
cleaned = _strip_floating_markup_fragment(text)
|
||||
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
|
||||
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
||||
return "\n".join(line for line in lines if line)
|
||||
|
||||
|
||||
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
||||
fallback = _strip_floating_markup_fragment(raw_text or "")
|
||||
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
||||
return fallback or _FLOATING_EMPTY_FALLBACK
|
||||
|
||||
|
||||
class _FloatingStreamSanitizer:
|
||||
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending = ""
|
||||
|
||||
@staticmethod
|
||||
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
||||
boundary = len(text)
|
||||
|
||||
last_lt = text.rfind("<")
|
||||
if last_lt != -1 and ">" not in text[last_lt:]:
|
||||
boundary = min(boundary, last_lt)
|
||||
|
||||
last_lb = text.rfind("[")
|
||||
if last_lb != -1 and "]" not in text[last_lb:]:
|
||||
boundary = min(boundary, last_lb)
|
||||
|
||||
if boundary == len(text):
|
||||
return text, ""
|
||||
return text[:boundary], text[boundary:]
|
||||
|
||||
def feed(self, chunk: str) -> str:
|
||||
combined = f"{self._pending}{chunk}"
|
||||
safe_text, self._pending = self._split_safe_boundary(combined)
|
||||
return _strip_floating_markup_fragment(safe_text)
|
||||
|
||||
def finalize(self) -> str:
|
||||
# Drop dangling unfinished wrappers at the very end.
|
||||
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
||||
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
||||
self._pending = ""
|
||||
return _strip_floating_markup_fragment(tail)
|
||||
|
||||
|
||||
def _normalize_memory_label(path_or_label: str) -> str:
|
||||
value = path_or_label.strip()
|
||||
if value.startswith("/memories/"):
|
||||
value = value[len("/memories/"):]
|
||||
value = value.strip("/")
|
||||
return value
|
||||
|
||||
|
||||
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
@tool
|
||||
async def memory_list_blocks() -> str:
|
||||
"""List all core memory blocks currently stored for the user."""
|
||||
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
blocks = await memory.list_core_blocks(user_id)
|
||||
if not blocks:
|
||||
return "No memory blocks found."
|
||||
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
||||
return "Memory blocks:\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def memory_get(path_or_label: str) -> str:
|
||||
"""Get one memory block by label or /memories/<label> path."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
value = await memory.get_core_block(user_id, label)
|
||||
if value is None:
|
||||
return f"Memory block '{label}' not found."
|
||||
return f"Memory block '{label}':\n{value}"
|
||||
|
||||
@tool
|
||||
async def memory_create(path_or_label: str, value: str) -> str:
|
||||
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
||||
return f"Memory block '{label}' saved."
|
||||
|
||||
@tool
|
||||
async def memory_append(path_or_label: str, content: str) -> str:
|
||||
"""Append content to a memory block, creating it if missing."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.append_core(user_id, label, content)
|
||||
return f"Memory block '{label}' appended."
|
||||
|
||||
@tool
|
||||
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
||||
"""Replace one exact string in a memory block."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
||||
if not changed:
|
||||
return f"No replacement made in '{label}' (old string not found)."
|
||||
return f"Memory block '{label}' updated."
|
||||
|
||||
@tool
|
||||
async def memory_delete(path_or_label: str) -> str:
|
||||
"""Delete a memory block by label or /memories/<label> path."""
|
||||
label = _normalize_memory_label(path_or_label)
|
||||
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
||||
if not label:
|
||||
return "Invalid memory label."
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
deleted = await memory.delete_core(user_id, label)
|
||||
if not deleted:
|
||||
return f"Memory block '{label}' not found."
|
||||
return f"Memory block '{label}' deleted."
|
||||
|
||||
@tool
|
||||
async def archival_memory_insert(content: str) -> str:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.insert_archival(user_id, content, source="assistant")
|
||||
return "Archival memory saved."
|
||||
|
||||
@tool
|
||||
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
||||
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
||||
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
results = await memory.search_archival(user_id, query, top_k=top_k)
|
||||
if not results:
|
||||
return "No archival memory results found."
|
||||
lines = [f"- {item}" for item in results]
|
||||
return "Archival memory results:\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def conversation_search(query: str, top_k: int = 5) -> str:
|
||||
"""Search recall memory from prior episodic conversation summaries."""
|
||||
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
results = await memory.search_recall(user_id, query, top_k=top_k)
|
||||
if not results:
|
||||
return "No recall memory results found."
|
||||
lines = [f"- {item}" for item in results]
|
||||
return "Recall memory results:\n" + "\n".join(lines)
|
||||
|
||||
return [
|
||||
memory_list_blocks,
|
||||
memory_get,
|
||||
memory_create,
|
||||
memory_append,
|
||||
memory_replace,
|
||||
memory_delete,
|
||||
archival_memory_insert,
|
||||
archival_memory_search,
|
||||
conversation_search,
|
||||
]
|
||||
|
||||
|
||||
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
||||
|
||||
|
||||
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
||||
lowered = message.lower()
|
||||
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
||||
return "timeline"
|
||||
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
||||
return "task"
|
||||
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
||||
return "note"
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
||||
type_raw = str(payload.get("type") or "").strip().lower()
|
||||
domain_type: FloatingDomainType = "task"
|
||||
if type_raw in {"task", "timeline", "project", "node"}:
|
||||
domain_type = type_raw
|
||||
|
||||
id_value = payload.get("id")
|
||||
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
||||
if domain_type == "project" and not domain_id:
|
||||
domain_id = fallback_id
|
||||
|
||||
section_raw = payload.get("section")
|
||||
section: FloatingDomainSection | None = None
|
||||
if isinstance(section_raw, str):
|
||||
section_candidate = section_raw.strip().lower()
|
||||
if section_candidate in {"task", "timeline", "note"}:
|
||||
section = section_candidate
|
||||
|
||||
if domain_type != "project":
|
||||
section = None
|
||||
|
||||
return {
|
||||
"type": domain_type,
|
||||
"id": domain_id,
|
||||
"section": section,
|
||||
}
|
||||
|
||||
|
||||
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
||||
raw = text.strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
try:
|
||||
parsed = json.loads(match.group(0))
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||
section = _detect_domain_section(message)
|
||||
scope = context.get("scope") if isinstance(context, dict) else None
|
||||
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||
|
||||
if isinstance(scope, dict):
|
||||
scope_type = str(scope.get("type") or "").strip().lower()
|
||||
scope_id = scope.get("id")
|
||||
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
||||
|
||||
if scope_type in {"task", "tasks"}:
|
||||
return {"type": "task", "id": scope_id_value, "section": None}
|
||||
if scope_type in {"project", "projects"}:
|
||||
project_scope_id = scope_id_value or project_id
|
||||
return {
|
||||
"type": "project",
|
||||
"id": project_scope_id,
|
||||
"section": section,
|
||||
}
|
||||
if scope_type in {"note", "notes"}:
|
||||
return {
|
||||
"type": "node",
|
||||
"id": scope_id_value,
|
||||
"section": None,
|
||||
}
|
||||
if scope_type in {"timeline", "timelines"}:
|
||||
return {"type": "timeline", "id": scope_id_value, "section": None}
|
||||
|
||||
lowered = message.lower()
|
||||
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
||||
return {
|
||||
"type": "project",
|
||||
"id": project_id,
|
||||
"section": section,
|
||||
}
|
||||
if section == "timeline":
|
||||
return {"type": "timeline", "id": None, "section": None}
|
||||
if section == "note":
|
||||
return {"type": "node", "id": None, "section": None}
|
||||
return {"type": "task", "id": None, "section": None}
|
||||
|
||||
|
||||
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
||||
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
||||
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
||||
|
||||
classifier_context = {
|
||||
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
||||
"resolved_project_id": project_id,
|
||||
}
|
||||
|
||||
try:
|
||||
llm = get_llm()
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"Message:\n{message}\n\n"
|
||||
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
parsed = _parse_json_object(_as_text(response.content))
|
||||
if parsed is not None:
|
||||
domain = _normalize_domain_payload(parsed, project_id)
|
||||
logger.info(
|
||||
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
||||
domain.get("type"),
|
||||
domain.get("id"),
|
||||
domain.get("section"),
|
||||
)
|
||||
return domain
|
||||
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
||||
except Exception as exc:
|
||||
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
||||
|
||||
return _infer_floating_domain_rule_based(message, context)
|
||||
|
||||
|
||||
async def _run_single_agent(
|
||||
*,
|
||||
user_id: str,
|
||||
system_prompt: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
max_steps: int = 6,
|
||||
) -> str:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
llm = get_llm()
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"User message:\n{message}\n\n"
|
||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
tool_calls_count = 0
|
||||
collected: list[dict[str, Any]] = []
|
||||
set_tool_result_collector(collected)
|
||||
try:
|
||||
for _ in range(max_steps):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
final_text = _as_text(response.content)
|
||||
logger.info(
|
||||
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
tool_calls_count,
|
||||
len(final_text),
|
||||
)
|
||||
return final_text
|
||||
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_calls_count += 1
|
||||
call_id = str(call.get("id", ""))
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||
call_id,
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||
call_id,
|
||||
call_name,
|
||||
str(tool_output)[:1200],
|
||||
)
|
||||
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
final = await llm.ainvoke(messages)
|
||||
final_text = _as_text(final.content)
|
||||
logger.info(
|
||||
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
tool_calls_count,
|
||||
len(final_text),
|
||||
)
|
||||
return final_text
|
||||
finally:
|
||||
clear_tool_result_collector()
|
||||
|
||||
|
||||
async def _run_single_agent_stream(
|
||||
*,
|
||||
user_id: str,
|
||||
system_prompt: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
max_steps: int = 6,
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
llm = get_llm()
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"User message:\n{message}\n\n"
|
||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
tool_calls_count = 0
|
||||
streamed_chars = 0
|
||||
collected: list[dict[str, Any]] = []
|
||||
set_tool_result_collector(collected)
|
||||
try:
|
||||
for _ in range(max_steps):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
emitted_any = False
|
||||
async for chunk in llm.astream(messages):
|
||||
token = _as_text(getattr(chunk, "content", ""))
|
||||
if token:
|
||||
streamed_chars += len(token)
|
||||
emitted_any = True
|
||||
yield "token", token
|
||||
|
||||
# Some providers return final text in `response.content` but stream no chunks.
|
||||
if not emitted_any:
|
||||
fallback_text = _as_text(response.content)
|
||||
if fallback_text:
|
||||
streamed_chars += len(fallback_text)
|
||||
yield "token", fallback_text
|
||||
logger.info(
|
||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
tool_calls_count,
|
||||
streamed_chars,
|
||||
)
|
||||
return
|
||||
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_calls_count += 1
|
||||
call_id = str(call.get("id", ""))
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
||||
call_id,
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
||||
call_id,
|
||||
call_name,
|
||||
str(tool_output)[:1200],
|
||||
)
|
||||
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
async for chunk in llm.astream(messages):
|
||||
token = _as_text(getattr(chunk, "content", ""))
|
||||
if token:
|
||||
streamed_chars += len(token)
|
||||
yield "token", token
|
||||
logger.info(
|
||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
tool_calls_count,
|
||||
streamed_chars,
|
||||
)
|
||||
finally:
|
||||
clear_tool_result_collector()
|
||||
|
||||
|
||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
)
|
||||
return _normalize_tagged_list_lines(response, message)
|
||||
|
||||
|
||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
domain = await _infer_floating_domain(message, prepared_context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
)
|
||||
sanitized = _strip_floating_markup(response)
|
||||
if not sanitized and response:
|
||||
sanitized = _fallback_from_raw_floating_text(response)
|
||||
return sanitized, domain
|
||||
|
||||
|
||||
async def run_home_stream(
|
||||
user_id: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
text_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
):
|
||||
event_type, data = event
|
||||
if event_type != "token":
|
||||
yield event
|
||||
continue
|
||||
text_chunks.append(str(data or ""))
|
||||
|
||||
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
|
||||
if normalized:
|
||||
yield "token", normalized
|
||||
|
||||
|
||||
async def run_floating_stream(
|
||||
user_id: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
prepared_context = await _prepare_context(message, context)
|
||||
domain = await _infer_floating_domain(message, prepared_context)
|
||||
yield "floating_domain", domain
|
||||
|
||||
sanitizer = _FloatingStreamSanitizer()
|
||||
emitted_sanitized = False
|
||||
raw_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
||||
message=message,
|
||||
context=prepared_context,
|
||||
):
|
||||
event_type, data = event
|
||||
if event_type != "token":
|
||||
yield event
|
||||
continue
|
||||
|
||||
raw_chunk = str(data or "")
|
||||
raw_chunks.append(raw_chunk)
|
||||
sanitized_chunk = sanitizer.feed(raw_chunk)
|
||||
if sanitized_chunk:
|
||||
emitted_sanitized = True
|
||||
yield "token", sanitized_chunk
|
||||
|
||||
tail = sanitizer.finalize()
|
||||
if tail:
|
||||
emitted_sanitized = True
|
||||
yield "token", tail
|
||||
|
||||
if not emitted_sanitized and raw_chunks:
|
||||
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
||||
|
||||
|
||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
||||
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.update_core(user_id, key, value)
|
||||
@@ -1,151 +0,0 @@
|
||||
"""Device connection manager.
|
||||
|
||||
Maintains in-memory state for all active Electron → backend WebSocket
|
||||
connections. One connection per user (latest replaces previous).
|
||||
|
||||
The manager handles the **tool-call round-trip** pattern:
|
||||
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||
returns ``tool_result`` frame.
|
||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||
receive the result dict from Electron.
|
||||
|
||||
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||
``execute_on_client()`` in ``ws_context.py``.
|
||||
|
||||
The ``device_manager`` module-level singleton is imported by both the
|
||||
device WS route and the agent runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceConnection:
|
||||
"""State for a single connected Electron device."""
|
||||
|
||||
ws: WebSocket
|
||||
device_id: str
|
||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DeviceConnectionManager:
|
||||
"""Singleton registry of active Electron WebSocket connections.
|
||||
|
||||
Thread/task safety note: asyncio is single-threaded by design. All
|
||||
mutations happen inside await-points on the main event loop, so no
|
||||
locking is required for the in-memory dicts.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections: dict[str, DeviceConnection] = {}
|
||||
|
||||
# ── Registration ──────────────────────────────────────────────────
|
||||
|
||||
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||
if user_id in self._connections:
|
||||
old = self._connections[user_id]
|
||||
logger.info(
|
||||
"device_manager: replacing existing connection for user=%s device=%s",
|
||||
user_id,
|
||||
old.device_id,
|
||||
)
|
||||
# Cancel any futures that were waiting on the old connection.
|
||||
for fut in old.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||
logger.info(
|
||||
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||
)
|
||||
|
||||
def unregister(self, user_id: str) -> None:
|
||||
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||
conn = self._connections.pop(user_id, None)
|
||||
if conn is None:
|
||||
return
|
||||
for fut in conn.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
logger.info("device_manager: unregistered user=%s", user_id)
|
||||
|
||||
# ── Presence queries ──────────────────────────────────────────────
|
||||
|
||||
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||
conn = self._connections.get(user_id)
|
||||
return conn.ws if conn else None
|
||||
|
||||
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||
"""Return ``True`` if the user has an active connection.
|
||||
|
||||
If *device_id* is provided also checks that it matches the connected device.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return False
|
||||
if device_id is not None:
|
||||
return conn.device_id == device_id
|
||||
return True
|
||||
|
||||
# ── Frame sending ─────────────────────────────────────────────────
|
||||
|
||||
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||
"""Send *frame* as a JSON text message to the device.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"send_frame: user {user_id!r} is not connected"
|
||||
)
|
||||
await conn.ws.send_text(json.dumps(frame))
|
||||
|
||||
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||
|
||||
def create_pending_call(
|
||||
self, user_id: str, call_id: str
|
||||
) -> asyncio.Future[dict]:
|
||||
"""Register a Future that will be resolved when the tool_result arrives.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"create_pending_call: user {user_id!r} is not connected"
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
fut: asyncio.Future[dict] = loop.create_future()
|
||||
conn.pending_calls[call_id] = fut
|
||||
return fut
|
||||
|
||||
def resolve_pending_call(
|
||||
self, user_id: str, call_id: str, result: dict
|
||||
) -> None:
|
||||
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||
|
||||
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return
|
||||
fut = conn.pending_calls.pop(call_id, None)
|
||||
if fut is not None and not fut.done():
|
||||
fut.set_result(result)
|
||||
|
||||
|
||||
# Module-level singleton — import this everywhere.
|
||||
device_manager = DeviceConnectionManager()
|
||||
122
app/core/llm.py
122
app/core/llm.py
@@ -1,122 +0,0 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Every agent and the orchestrator call ``get_llm()``
|
||||
instead of directly constructing a provider-specific class. The model string
|
||||
follows the `LiteLLM model naming convention
|
||||
<https://docs.litellm.ai/docs/providers>`_:
|
||||
|
||||
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
|
||||
* Anthropic: ``anthropic/claude-3.5-sonnet``
|
||||
* Google: ``gemini/gemini-pro``
|
||||
* Ollama: ``ollama/llama3``
|
||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||
|
||||
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||
— no code changes required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
|
||||
# Drop them silently instead of raising UnsupportedParamsError.
|
||||
litellm.drop_params = True
|
||||
|
||||
# Some provider responses include a plain dict in the `usage` field where a
|
||||
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _api_key_for_model(model: str) -> str | None:
|
||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||
if model.startswith("anthropic/"):
|
||||
return settings.ANTHROPIC_API_KEY or None
|
||||
if model.startswith("gemini/") or model.startswith("google/"):
|
||||
return settings.GOOGLE_API_KEY or None
|
||||
if model.startswith("cerebras/"):
|
||||
return settings.CEREBRAS_API_KEY or None
|
||||
if model.startswith("github/"):
|
||||
return settings.GITHUB_TOKEN or None
|
||||
if model.startswith("github_copilot/"):
|
||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||
# No API key is required; returning None lets LiteLLM handle auth.
|
||||
return None
|
||||
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||
return settings.OPENAI_API_KEY or None
|
||||
|
||||
|
||||
def get_llm(
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return a LangChain chat model backed by LiteLLM.
|
||||
|
||||
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
|
||||
``openai`` client transparently when the model string contains a provider
|
||||
prefix (``anthropic/…``, ``gemini/…``, etc.).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model:
|
||||
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
|
||||
temperature:
|
||||
Sampling temperature. ``0`` = deterministic.
|
||||
"""
|
||||
model = model or settings.LLM_MODEL
|
||||
|
||||
# Point LiteLLM to the custom token directory when configured.
|
||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||
|
||||
if settings.GITHUB_TOKEN:
|
||||
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
||||
|
||||
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||
if "/" in model:
|
||||
return ChatLiteLLM(model=model, temperature=temperature)
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=_api_key_for_model(model),
|
||||
)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
"""Return an embedding vector for *text*.
|
||||
|
||||
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||
model names to preserve existing behaviour.
|
||||
"""
|
||||
model = settings.LLM_EMBED_MODEL
|
||||
|
||||
if model.startswith("github_copilot/") or "/" in model:
|
||||
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||
# so the provider's auth mechanism is applied correctly.
|
||||
response = await litellm.aembedding(model=model, input=[text])
|
||||
return response.data[0]["embedding"]
|
||||
|
||||
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
response = await client.embeddings.create(model=model, input=text)
|
||||
return response.data[0].embedding
|
||||
@@ -1,441 +0,0 @@
|
||||
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||
|
||||
Four-tier memory model (MemGPT-style):
|
||||
core — persistent key/value user preferences, always injected
|
||||
associative — semantic similarity search via pgvector (top-k)
|
||||
episodic — recent session summaries (last N)
|
||||
proactive — behavioral patterns above confidence threshold
|
||||
|
||||
All memory content is encrypted at rest using the per-user Fernet key
|
||||
stored in User.encryption_key. Decryption happens in-memory only.
|
||||
|
||||
Usage:
|
||||
memory = MemoryMiddleware(db_session)
|
||||
context = await memory.enrich_context(user_id, message)
|
||||
# ... run agent ...
|
||||
await memory.store_episode(user_id, session_id, message, response)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tuning constants
|
||||
_ASSOCIATIVE_TOP_K = 5
|
||||
_EPISODIC_RECENT_N = 10
|
||||
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||
|
||||
|
||||
class MemoryMiddleware:
|
||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self._db = db
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────
|
||||
|
||||
async def enrich_context(
|
||||
self,
|
||||
user_id: str,
|
||||
message: str,
|
||||
trace_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||
|
||||
Returns a dict with keys:
|
||||
core_memory — {key: plaintext_value, ...}
|
||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return {}
|
||||
|
||||
core = await self._load_core(user_id, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet)
|
||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||
proactive = await self._load_proactive(user_id, fernet)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
len(core),
|
||||
len(associative),
|
||||
len(episodic),
|
||||
len(proactive),
|
||||
)
|
||||
|
||||
return {
|
||||
"core_memory": core,
|
||||
"associative_memory": associative,
|
||||
"episodic_memory": episodic,
|
||||
"proactive_hints": proactive,
|
||||
}
|
||||
|
||||
async def store_episode(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
trace_id: str | None = None,
|
||||
) -> None:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. Full LLM summarisation can be added in a later step.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
row = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, value)
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == key,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing is not None:
|
||||
existing.value_encrypted = encrypted
|
||||
else:
|
||||
self._db.add(MemoryCore(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
value_encrypted=encrypted,
|
||||
))
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
key,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||
"""Return core memory as editable blocks (label/value)."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore)
|
||||
.where(MemoryCore.user_id == user_id)
|
||||
.order_by(MemoryCore.key.asc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[dict[str, str]] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append({"label": row.key, "value": plaintext})
|
||||
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||
return out
|
||||
|
||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||
"""Return a single core memory block value by label."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return None
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == label,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||
return None
|
||||
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||
return value
|
||||
|
||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||
"""Delete a core memory block by label. Returns True if deleted."""
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == label,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||
return False
|
||||
|
||||
await self._db.delete(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||
await self._db.rollback()
|
||||
return False
|
||||
|
||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||
"""Append content to a core block, creating it if missing."""
|
||||
current = await self.get_core_block(user_id, label)
|
||||
if current is None:
|
||||
await self.update_core(user_id, label, content)
|
||||
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||
return
|
||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||
|
||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||
current = await self.get_core_block(user_id, label)
|
||||
if current is None or old not in current:
|
||||
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||
return False
|
||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||
return True
|
||||
|
||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=None,
|
||||
entity_type=source,
|
||||
entity_id=None,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||
except Exception as exc:
|
||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(100)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
needle = query.strip().lower()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is None:
|
||||
continue
|
||||
if not needle or needle in plaintext.lower():
|
||||
out.append(plaintext)
|
||||
if len(out) >= max(top_k, 1):
|
||||
break
|
||||
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||
return out
|
||||
|
||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search recall memory (episodic summaries) by keyword."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryEpisodic)
|
||||
.where(MemoryEpisodic.user_id == user_id)
|
||||
.order_by(MemoryEpisodic.created_at.desc())
|
||||
.limit(100)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
needle = query.strip().lower()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||
if plaintext is None:
|
||||
continue
|
||||
if not needle or needle in plaintext.lower():
|
||||
out.append(plaintext)
|
||||
if len(out) >= max(top_k, 1):
|
||||
break
|
||||
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||
return out
|
||||
|
||||
# ── Private helpers ───────────────────────────────────────────────────────
|
||||
|
||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||
return None
|
||||
return Fernet(user.encryption_key.encode())
|
||||
|
||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||
"""Load lightweight user debug fields for trace logs."""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
return {"tier": None}
|
||||
return {
|
||||
"tier": user.tier,
|
||||
}
|
||||
|
||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: dict[str, str] = {}
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||
if plaintext is not None:
|
||||
out[row.key] = plaintext
|
||||
return out
|
||||
|
||||
async def _load_associative(
|
||||
self, user_id: str, message: str, fernet: Fernet
|
||||
) -> list[str]:
|
||||
"""Load top-k associative memories.
|
||||
|
||||
Production: uses pgvector cosine similarity on the message embedding.
|
||||
Current implementation: keyword-based fallback (no external embedding call)
|
||||
so tests pass without a live OpenAI key.
|
||||
"""
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_episodic(
|
||||
self,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
session_id: str | None = None,
|
||||
) -> list[str]:
|
||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||
if session_id:
|
||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||
result = await self._db.execute(
|
||||
query
|
||||
.order_by(MemoryEpisodic.created_at.desc())
|
||||
.limit(_EPISODIC_RECENT_N)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryProactive)
|
||||
.where(
|
||||
MemoryProactive.user_id == user_id,
|
||||
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||
)
|
||||
.order_by(MemoryProactive.confidence.desc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
|
||||
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||
|
||||
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||
return fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
|
||||
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||
try:
|
||||
return fernet.decrypt(ciphertext.encode()).decode()
|
||||
except (InvalidToken, Exception) as exc:
|
||||
logger.warning("memory: decrypt failed: %s", exc)
|
||||
return None
|
||||
@@ -1,47 +0,0 @@
|
||||
"""Output formatter for deep-agent stream events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||
|
||||
|
||||
class StreamFormatter:
|
||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
|
||||
async def format(
|
||||
self,
|
||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
started = False
|
||||
|
||||
async for event_type, data in event_stream:
|
||||
if event_type == "floating_domain":
|
||||
if isinstance(data, dict):
|
||||
yield WsFloatingDomain(
|
||||
request_id=self.request_id,
|
||||
domain=data,
|
||||
)
|
||||
continue
|
||||
|
||||
if event_type != "token":
|
||||
continue
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
started = True
|
||||
|
||||
text = str(data or "")
|
||||
if text:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
@@ -1,92 +0,0 @@
|
||||
"""WebSocket client executor context.
|
||||
|
||||
Holds a per-request async callback that tools call to execute CRUD
|
||||
operations on the Electron client's local SQLite / LanceDB databases.
|
||||
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
# Holds the execute callback for the current WS session.
|
||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||
"_client_executor"
|
||||
)
|
||||
|
||||
# Optional collector that captures raw execute_on_client results.
|
||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||
"""Register *lst* as the collector for this async context."""
|
||||
_tool_result_collector.set(lst)
|
||||
|
||||
|
||||
def clear_tool_result_collector() -> None:
|
||||
"""Clear the collector (best-effort)."""
|
||||
_tool_result_collector.set(None)
|
||||
|
||||
|
||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||
_client_executor.set(fn)
|
||||
|
||||
|
||||
def clear_client_executor() -> None:
|
||||
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
|
||||
try:
|
||||
_client_executor.set(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def execute_on_client(
|
||||
action: str,
|
||||
table: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
vector: list[float] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a CRUD/vector operation to the Electron client and return the result.
|
||||
|
||||
Builds a ``tool_call`` payload, invokes the per-session WS callback,
|
||||
and returns the ``tool_result`` dict from Electron.
|
||||
|
||||
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
|
||||
"""
|
||||
callback = _client_executor.get(None)
|
||||
if callback is None:
|
||||
raise RuntimeError(
|
||||
"execute_on_client() called outside a WebSocket session — "
|
||||
"no client executor is set."
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
|
||||
if table is not None:
|
||||
payload["table"] = table
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
if filters is not None:
|
||||
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||
if vector is not None:
|
||||
payload["vector"] = vector
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
result = await callback(payload)
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append({
|
||||
"action": action,
|
||||
"table": table,
|
||||
"data": result,
|
||||
})
|
||||
return result
|
||||
40
app/db.py
40
app/db.py
@@ -1,40 +0,0 @@
|
||||
"""Database engine, session factory, and base model.
|
||||
|
||||
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||
|
||||
Usage in routes:
|
||||
from app.db import get_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Shared declarative base for all ORM models."""
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session per request."""
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Cloud provider integration utilities.
|
||||
|
||||
Provides:
|
||||
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||
* ``get_provider()`` — factory that returns the correct client given a
|
||||
provider name and decrypted OAuth credentials dict.
|
||||
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||
|
||||
Encryption rationale
|
||||
--------------------
|
||||
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||
because the backend makes provider API calls on behalf of the user.
|
||||
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||
is never returned to clients.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.integrations.gmail import GmailClient
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Shared message types ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailMessage:
|
||||
"""A single email message fetched from Gmail or Outlook."""
|
||||
|
||||
id: str
|
||||
subject: str
|
||||
sender: str
|
||||
body_text: str
|
||||
date: datetime
|
||||
labels: list[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def as_text(self) -> str:
|
||||
"""Return a human-readable text representation for LLM extraction."""
|
||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||
return (
|
||||
f"From: {self.sender}\n"
|
||||
f"Date: {date_str}{labels_str}\n"
|
||||
f"Subject: {self.subject}\n\n"
|
||||
f"{self.body_text}"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""A single Teams chat or channel message fetched from MS Graph."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
sender: str
|
||||
channel: str | None
|
||||
date: datetime
|
||||
|
||||
@property
|
||||
def as_text(self) -> str:
|
||||
"""Return a human-readable text representation for LLM extraction."""
|
||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||
return (
|
||||
f"From: {self.sender}\n"
|
||||
f"Date: {date_str}{channel_str}\n\n"
|
||||
f"{self.content}"
|
||||
)
|
||||
|
||||
|
||||
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
|
||||
|
||||
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
|
||||
must ensure this is configured before persisting OAuth tokens.
|
||||
"""
|
||||
key = settings.OAUTH_ENCRYPTION_KEY
|
||||
if not key:
|
||||
raise RuntimeError(
|
||||
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||
)
|
||||
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||
|
||||
|
||||
def encrypt_token(token_info: dict) -> str:
|
||||
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
|
||||
|
||||
Stores the full ``{access_token, refresh_token, token_uri, client_id,
|
||||
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||
ValueError: ``token_info`` is not a non-empty dict.
|
||||
"""
|
||||
if not isinstance(token_info, dict) or not token_info:
|
||||
raise ValueError("token_info must be a non-empty dict")
|
||||
plaintext = json.dumps(token_info).encode("utf-8")
|
||||
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||
|
||||
|
||||
def decrypt_token(encrypted: str) -> dict:
|
||||
"""Decrypt a Fernet-encrypted token string and return the credential dict.
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||
ValueError: The encrypted string is invalid or was encrypted with a
|
||||
different key.
|
||||
"""
|
||||
try:
|
||||
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||
return json.loads(plaintext)
|
||||
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||
|
||||
|
||||
# ── Provider factory ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_provider(
|
||||
provider: str,
|
||||
credentials_info: dict,
|
||||
) -> "GmailClient | MSGraphClient":
|
||||
"""Return the correct provider client for *provider*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
provider:
|
||||
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
|
||||
credentials_info:
|
||||
Decrypted OAuth credential dict (Google or Microsoft shape).
|
||||
|
||||
Raises:
|
||||
ValueError: Unknown provider name.
|
||||
"""
|
||||
if provider == "gmail":
|
||||
from app.integrations.gmail import GmailClient
|
||||
return GmailClient(credentials_info)
|
||||
if provider in {"outlook", "teams"}:
|
||||
from app.integrations.ms_graph import MSGraphClient
|
||||
return MSGraphClient(credentials_info)
|
||||
raise ValueError(
|
||||
f"Unknown cloud provider {provider!r}. "
|
||||
"Supported: 'gmail', 'outlook', 'teams'."
|
||||
)
|
||||
@@ -1,335 +0,0 @@
|
||||
"""Gmail API client for cloud agent integration.
|
||||
|
||||
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Token refresh is handled transparently: when the stored access token has
|
||||
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||
token to obtain a fresh one. The caller is responsible for persisting
|
||||
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||
(see ``agent_runner.run_cloud_agent``).
|
||||
|
||||
Credential dict shape (Google OAuth2):
|
||||
{
|
||||
"token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "<client_id>",
|
||||
"client_secret": "<client_secret>",
|
||||
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import email
|
||||
import html
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from app.integrations import EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Gmail search date format — e.g. "after:2025/01/01"
|
||||
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||
|
||||
# Maximum characters of body text forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
|
||||
def _build_gmail_query(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||
senders (list[str]): Sender addresses or domains to include
|
||||
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||
|
||||
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||
when it is earlier.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Labels — joined with OR when multiple given.
|
||||
labels: list[str] = cfg.get("labels", [])
|
||||
if labels:
|
||||
if len(labels) == 1:
|
||||
parts.append(f"label:{labels[0]}")
|
||||
else:
|
||||
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||
parts.append(f"({label_expr})")
|
||||
|
||||
# Senders — each prefixed with "from:".
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
for sender in senders:
|
||||
parts.append(f"from:{sender}")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
to_str: str | None = date_range.get("to")
|
||||
|
||||
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||
effective_since: datetime | None = since
|
||||
if from_str:
|
||||
try:
|
||||
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||
if cfg_since.tzinfo is None:
|
||||
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||
if effective_since is None or cfg_since > effective_since:
|
||||
effective_since = cfg_since
|
||||
except ValueError:
|
||||
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||
|
||||
if effective_since:
|
||||
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||
|
||||
if to_str:
|
||||
try:
|
||||
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||
except ValueError:
|
||||
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _strip_html(raw_html: str) -> str:
|
||||
"""Remove HTML tags and decode entities to get plain text."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||
decoded = html.unescape(no_tags)
|
||||
return re.sub(r"\s+", " ", decoded).strip()
|
||||
|
||||
|
||||
def _parse_body(payload: dict[str, Any]) -> str:
|
||||
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||
|
||||
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||
Returns an empty string if no body can be extracted.
|
||||
"""
|
||||
mime_type: str = payload.get("mimeType", "")
|
||||
body: dict = payload.get("body", {})
|
||||
parts: list[dict] = payload.get("parts", [])
|
||||
|
||||
if mime_type == "text/plain":
|
||||
data = body.get("data", "")
|
||||
if data:
|
||||
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return ""
|
||||
|
||||
if mime_type == "text/html":
|
||||
data = body.get("data", "")
|
||||
if data:
|
||||
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return _strip_html(raw)
|
||||
return ""
|
||||
|
||||
# Multipart — prefer text/plain part, fall back to text/html.
|
||||
plain_fallback = ""
|
||||
for part in parts:
|
||||
part_mime = part.get("mimeType", "")
|
||||
if part_mime == "text/plain":
|
||||
return _parse_body(part)
|
||||
if part_mime == "text/html" and not plain_fallback:
|
||||
plain_fallback = _parse_body(part)
|
||||
if part_mime.startswith("multipart/"):
|
||||
nested = _parse_body(part)
|
||||
if nested:
|
||||
return nested
|
||||
return plain_fallback
|
||||
|
||||
|
||||
def _parse_date(raw: str) -> datetime:
|
||||
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||
try:
|
||||
parsed = email.utils.parsedate_to_datetime(raw)
|
||||
if parsed.tzinfo is None:
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed.astimezone(timezone.utc)
|
||||
except Exception:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class GmailClient:
|
||||
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||
``client_id`` + ``client_secret``.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
self._credentials_info = credentials_info
|
||||
expiry_str: str | None = credentials_info.get("expiry")
|
||||
expiry: datetime | None = None
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry = datetime.fromisoformat(
|
||||
expiry_str.replace("Z", "+00:00")
|
||||
).replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self._credentials = Credentials(
|
||||
token=credentials_info.get("token"),
|
||||
refresh_token=credentials_info.get("refresh_token"),
|
||||
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||
client_id=credentials_info.get("client_id"),
|
||||
client_secret=credentials_info.get("client_secret"),
|
||||
scopes=credentials_info.get("scopes"),
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_messages(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||
|
||||
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||
to avoid blocking the async event loop.
|
||||
|
||||
Token refresh is performed automatically when the access token has
|
||||
expired. After the call, ``self.refreshed_credentials`` may be
|
||||
consulted to detect whether new credentials should be persisted.
|
||||
"""
|
||||
query = _build_gmail_query(filter_config, since)
|
||||
logger.debug("gmail: executing search query %r", query)
|
||||
return await asyncio.to_thread(self._fetch_sync, query)
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||
a new dict that should be re-encrypted and written back to the DB.
|
||||
Returns ``None`` if no refresh occurred.
|
||||
"""
|
||||
creds = self._credentials
|
||||
if not creds.valid and creds.expired:
|
||||
return None
|
||||
# Check whether the token changed from what was stored.
|
||||
if creds.token != self._credentials_info.get("token"):
|
||||
result = {
|
||||
"token": creds.token,
|
||||
"refresh_token": creds.refresh_token,
|
||||
"token_uri": creds.token_uri,
|
||||
"client_id": creds.client_id,
|
||||
"client_secret": creds.client_secret,
|
||||
"scopes": list(creds.scopes or []),
|
||||
}
|
||||
if creds.expiry:
|
||||
result["expiry"] = creds.expiry.isoformat()
|
||||
return result
|
||||
return None
|
||||
|
||||
# ── Internal sync worker ───────────────────────────────────────────────
|
||||
|
||||
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||
import googleapiclient.discovery
|
||||
import googleapiclient.errors
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
# Refresh token if needed before building the service.
|
||||
if self._credentials.expired and self._credentials.refresh_token:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||
|
||||
service = googleapiclient.discovery.build(
|
||||
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||
)
|
||||
user_api = service.users() # type: ignore[attr-defined]
|
||||
|
||||
# ── List matching message IDs ──────────────────────────────────────
|
||||
ids: list[str] = []
|
||||
page_token: str | None = None
|
||||
while len(ids) < _MAX_MESSAGES:
|
||||
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||
kwargs: dict[str, Any] = {
|
||||
"userId": "me",
|
||||
"maxResults": batch_size,
|
||||
}
|
||||
if query:
|
||||
kwargs["q"] = query
|
||||
if page_token:
|
||||
kwargs["pageToken"] = page_token
|
||||
|
||||
try:
|
||||
resp = user_api.messages().list(**kwargs).execute()
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||
|
||||
for msg in resp.get("messages", []):
|
||||
ids.append(msg["id"])
|
||||
|
||||
page_token = resp.get("nextPageToken")
|
||||
if not page_token:
|
||||
break
|
||||
|
||||
if not ids:
|
||||
logger.debug("gmail: no messages matched query %r", query)
|
||||
return []
|
||||
|
||||
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||
|
||||
# ── Fetch individual message details ──────────────────────────────
|
||||
messages: list[EmailMessage] = []
|
||||
for msg_id in ids:
|
||||
try:
|
||||
msg = user_api.messages().get(
|
||||
userId="me", id=msg_id, format="full"
|
||||
).execute()
|
||||
|
||||
headers: dict[str, str] = {
|
||||
h["name"].lower(): h["value"]
|
||||
for h in msg.get("payload", {}).get("headers", [])
|
||||
}
|
||||
subject = headers.get("subject", "(no subject)")
|
||||
sender = headers.get("from", "unknown")
|
||||
date_raw = headers.get("date", "")
|
||||
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||
|
||||
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||
labels = msg.get("labelIds", [])
|
||||
|
||||
messages.append(EmailMessage(
|
||||
id=msg_id,
|
||||
subject=subject,
|
||||
sender=sender,
|
||||
body_text=body_text,
|
||||
date=date,
|
||||
labels=labels,
|
||||
))
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||
except Exception as exc:
|
||||
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||
|
||||
logger.info("gmail: returned %d message(s)", len(messages))
|
||||
return messages
|
||||
@@ -1,352 +0,0 @@
|
||||
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||
|
||||
Handles two data sources:
|
||||
|
||||
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||
``/me/chats/getAllMessages`` filtered by date.
|
||||
|
||||
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
|
||||
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||
dependency) is used for all API calls.
|
||||
|
||||
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||
{
|
||||
"access_token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_type": "Bearer",
|
||||
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||
"expires_in": 3600
|
||||
}
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.integrations import ChatMessage, EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
# Max items fetched per run.
|
||||
_MAX_EMAILS = 200
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
# Max characters of body forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
|
||||
def _strip_html(raw: str) -> str:
|
||||
"""Strip HTML tags and collapse whitespace."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||
import html as _html
|
||||
decoded = _html.unescape(no_tags)
|
||||
return re.sub(r"\s+", " ", decoded).strip()
|
||||
|
||||
|
||||
def _odata_datetime(dt: datetime) -> str:
|
||||
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||
utc = dt.astimezone(timezone.utc)
|
||||
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
|
||||
def _build_email_filter(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
senders (list[str]): Sender email addresses.
|
||||
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||
folders (list[str]): Folder display names (not directly filterable
|
||||
via OData, so ignored here — callers iterate
|
||||
folder IDs separately if needed; listed for
|
||||
completeness).
|
||||
|
||||
A hard ``since`` date always overrides ``date_range.from`` when it is
|
||||
earlier.
|
||||
"""
|
||||
clauses: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Senders.
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
if senders:
|
||||
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
|
||||
effective_since: datetime | None = since
|
||||
if from_str:
|
||||
try:
|
||||
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||
if cfg_since.tzinfo is None:
|
||||
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||
if effective_since is None or cfg_since > effective_since:
|
||||
effective_since = cfg_since
|
||||
except ValueError:
|
||||
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||
|
||||
if effective_since:
|
||||
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||
|
||||
to_str: str | None = date_range.get("to")
|
||||
if to_str:
|
||||
try:
|
||||
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||
if to_dt.tzinfo is None:
|
||||
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||
except ValueError:
|
||||
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||
|
||||
return " and ".join(clauses)
|
||||
|
||||
|
||||
class MSGraphClient:
|
||||
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted MSAL credential dict.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
self._credentials_info = credentials_info
|
||||
self._access_token: str = credentials_info.get("access_token", "")
|
||||
self._original_access_token: str = self._access_token
|
||||
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||
|
||||
# ── Token management ───────────────────────────────────────────────────
|
||||
|
||||
def _auth_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._access_token}"}
|
||||
|
||||
async def _refresh_access_token(self) -> None:
|
||||
"""Use MSAL to exchange the refresh token for a fresh access token.
|
||||
|
||||
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
|
||||
|
||||
Raises:
|
||||
RuntimeError: MSAL reports an auth error.
|
||||
"""
|
||||
import msal
|
||||
|
||||
app = msal.ConfidentialClientApplication(
|
||||
client_id=settings.MS_CLIENT_ID,
|
||||
client_credential=settings.MS_CLIENT_SECRET,
|
||||
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||
)
|
||||
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||
if not scopes:
|
||||
scopes = ["https://graph.microsoft.com/.default"]
|
||||
|
||||
result = app.acquire_token_by_refresh_token(
|
||||
self._refresh_token,
|
||||
scopes=scopes,
|
||||
)
|
||||
if "access_token" not in result:
|
||||
error = result.get("error_description", result.get("error", "unknown"))
|
||||
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||
|
||||
self._access_token = result["access_token"]
|
||||
# MSAL may issue a new refresh token.
|
||||
if "refresh_token" in result:
|
||||
self._refresh_token = result["refresh_token"]
|
||||
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||
self._credentials_info["access_token"] = self._access_token
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
Returns ``None`` if no change was made.
|
||||
"""
|
||||
if self._access_token != self._original_access_token:
|
||||
return {**self._credentials_info, "access_token": self._access_token}
|
||||
return None
|
||||
|
||||
# ── HTTP helpers ───────────────────────────────────────────────────────
|
||||
|
||||
async def _get(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
url: str,
|
||||
params: dict[str, Any] | None = None,
|
||||
*,
|
||||
retry_on_401: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""GET *url* with auth; refresh token on 401 and retry once."""
|
||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||
logger.debug("ms_graph: 401 on %s — refreshing token", url)
|
||||
await self._refresh_access_token()
|
||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||
if resp.status_code == 429:
|
||||
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_emails(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter_config:
|
||||
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
|
||||
since:
|
||||
Hard lower-bound on email date (from last agent run).
|
||||
"""
|
||||
odata_filter = _build_email_filter(filter_config, since)
|
||||
params: dict[str, Any] = {
|
||||
"$top": 50,
|
||||
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||
"$orderby": "receivedDateTime desc",
|
||||
}
|
||||
if odata_filter:
|
||||
params["$filter"] = odata_filter
|
||||
|
||||
emails: list[EmailMessage] = []
|
||||
url = f"{_GRAPH_BASE}/me/messages"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
while url and len(emails) < _MAX_EMAILS:
|
||||
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||
for item in data.get("value", []):
|
||||
emails.append(self._parse_email(item))
|
||||
if len(emails) >= _MAX_EMAILS:
|
||||
break
|
||||
url = data.get("@odata.nextLink", "")
|
||||
params = {} # nextLink already contains encoded params.
|
||||
|
||||
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||
return emails
|
||||
|
||||
async def fetch_messages(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
|
||||
|
||||
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
|
||||
The ``filter_config.channels`` key is checked as a text-filter on
|
||||
the channel name post-fetch (the API doesn't support channel OData
|
||||
filter directly on ``getAllMessages``).
|
||||
"""
|
||||
cfg = filter_config or {}
|
||||
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||
params: dict[str, Any] = {"$top": 50}
|
||||
if since:
|
||||
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||
|
||||
messages: list[ChatMessage] = []
|
||||
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
while url and len(messages) < _MAX_MESSAGES:
|
||||
try:
|
||||
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
# getAllMessages requires specific licensing; degrade gracefully.
|
||||
if exc.response.status_code in (403, 404):
|
||||
logger.warning(
|
||||
"ms_graph: /me/chats/getAllMessages not available (%d) — "
|
||||
"check Teams license or permissions",
|
||||
exc.response.status_code,
|
||||
)
|
||||
break
|
||||
raise
|
||||
|
||||
for item in data.get("value", []):
|
||||
msg = self._parse_teams_message(item)
|
||||
if channel_filter and msg.channel:
|
||||
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||
continue
|
||||
messages.append(msg)
|
||||
if len(messages) >= _MAX_MESSAGES:
|
||||
break
|
||||
url = data.get("@odata.nextLink", "")
|
||||
params = {}
|
||||
|
||||
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||
return messages
|
||||
|
||||
# ── Parsers ────────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||
sender_block = item.get("from", {}) or {}
|
||||
sender_addr = (
|
||||
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||
)
|
||||
date_str: str = item.get("receivedDateTime", "")
|
||||
try:
|
||||
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except Exception:
|
||||
date = datetime.now(timezone.utc)
|
||||
|
||||
body_block = item.get("body", {}) or {}
|
||||
content_type: str = body_block.get("contentType", "text")
|
||||
raw_body: str = body_block.get("content", "")
|
||||
if content_type == "html":
|
||||
body_text = _strip_html(raw_body)
|
||||
else:
|
||||
body_text = raw_body or item.get("bodyPreview", "")
|
||||
body_text = body_text[:_BODY_TRUNCATE]
|
||||
|
||||
return EmailMessage(
|
||||
id=item.get("id", ""),
|
||||
subject=subject,
|
||||
sender=sender_addr,
|
||||
body_text=body_text,
|
||||
date=date,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||
msg_id: str = item.get("id", "")
|
||||
sender_block = (item.get("from") or {}).get("user") or {}
|
||||
sender: str = sender_block.get("displayName", "unknown")
|
||||
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||
|
||||
date_str: str = item.get("createdDateTime", "")
|
||||
try:
|
||||
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||
except Exception:
|
||||
date = datetime.now(timezone.utc)
|
||||
|
||||
body_block = item.get("body", {}) or {}
|
||||
content_type: str = body_block.get("contentType", "text")
|
||||
raw_content: str = body_block.get("content", "")
|
||||
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||
content = content[:_BODY_TRUNCATE]
|
||||
|
||||
return ChatMessage(
|
||||
id=msg_id,
|
||||
content=content,
|
||||
sender=sender,
|
||||
channel=channel,
|
||||
date=date,
|
||||
)
|
||||
72
app/main.py
72
app/main.py
@@ -1,72 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: ensure agent tool modules are loaded.
|
||||
import app.agents # noqa: F401
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
from app.db import engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="Adiuva Cloud API",
|
||||
version="0.1.0",
|
||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(storage.router, prefix="/api/v1")
|
||||
app.include_router(vectors.router, prefix="/api/v1")
|
||||
app.include_router(backup.router, prefix="/api/v1")
|
||||
app.include_router(plugins.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "version": app.version}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Plugin marketplace package.
|
||||
|
||||
Three service classes introduced in Step 10:
|
||||
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||
- ``ReviewQueue`` — approval workflow + security checklist
|
||||
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||
"""
|
||||
@@ -1,212 +0,0 @@
|
||||
"""Plugin catalog registry backed by PostgreSQL.
|
||||
|
||||
Maintains the authoritative list of plugins, their review status, and
|
||||
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
from app.marketplace.plugin_registry import registry
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import Plugin
|
||||
from app.schemas import PluginListResponse, PluginManifest
|
||||
|
||||
_PAGE_SIZE = 20
|
||||
|
||||
|
||||
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||
try:
|
||||
permissions = json.loads(p.permissions) if p.permissions else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
permissions = []
|
||||
return PluginManifest(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
version=p.version,
|
||||
author=p.author_name,
|
||||
permissions=permissions,
|
||||
category=p.category,
|
||||
price_cents=p.price_cents,
|
||||
)
|
||||
|
||||
|
||||
class PluginRegistry:
|
||||
"""PostgreSQL-backed plugin catalog.
|
||||
|
||||
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||
controls the session lifecycle.
|
||||
"""
|
||||
|
||||
# ── Queries ──────────────────────────────────────────────────────
|
||||
|
||||
async def list_plugins(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
category: str | None = None,
|
||||
query: str | None = None,
|
||||
page: int = 1,
|
||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||
) -> PluginListResponse:
|
||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||
base = select(Plugin).where(Plugin.status == "approved")
|
||||
|
||||
if category:
|
||||
base = base.where(Plugin.category == category)
|
||||
if query:
|
||||
pattern = f"%{query}%"
|
||||
base = base.where(
|
||||
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||
)
|
||||
|
||||
# Count
|
||||
count_q = select(func.count()).select_from(base.subquery())
|
||||
total = (await db.execute(count_q)).scalar_one()
|
||||
|
||||
# Sort
|
||||
if sort == "installs":
|
||||
base = base.order_by(Plugin.install_count.desc())
|
||||
elif sort == "rating":
|
||||
base = base.order_by(Plugin.avg_rating.desc())
|
||||
else: # newest
|
||||
base = base.order_by(Plugin.created_at.desc())
|
||||
|
||||
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||
rows = (await db.execute(base)).scalars().all()
|
||||
|
||||
return PluginListResponse(
|
||||
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
|
||||
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if p is None:
|
||||
return None
|
||||
return {
|
||||
"manifest": _plugin_to_manifest(p),
|
||||
"status": p.status,
|
||||
"install_count": p.install_count,
|
||||
"avg_rating": p.avg_rating,
|
||||
}
|
||||
|
||||
# ── Mutations ────────────────────────────────────────────────────
|
||||
|
||||
async def submit_plugin(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
manifest: PluginManifest,
|
||||
package_s3_key: str,
|
||||
) -> str:
|
||||
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||
|
||||
Returns the plugin_id. If a plugin with the same id already exists
|
||||
it is overwritten (re-submission after rejection).
|
||||
"""
|
||||
plugin_id = manifest.id
|
||||
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = existing.scalar_one_or_none()
|
||||
|
||||
if row is not None:
|
||||
row.name = manifest.name
|
||||
row.description = manifest.description
|
||||
row.version = manifest.version
|
||||
row.author_name = manifest.author
|
||||
row.category = manifest.category
|
||||
row.price_cents = manifest.price_cents
|
||||
row.permissions = json.dumps(manifest.permissions)
|
||||
row.status = "pending_review"
|
||||
row.s3_package_key = package_s3_key
|
||||
row.rejection_reason = None
|
||||
else:
|
||||
row = Plugin(
|
||||
id=plugin_id,
|
||||
name=manifest.name,
|
||||
description=manifest.description,
|
||||
version=manifest.version,
|
||||
author_name=manifest.author,
|
||||
category=manifest.category,
|
||||
price_cents=manifest.price_cents,
|
||||
permissions=json.dumps(manifest.permissions),
|
||||
status="pending_review",
|
||||
s3_package_key=package_s3_key,
|
||||
install_count=0,
|
||||
avg_rating=0.0,
|
||||
)
|
||||
db.add(row)
|
||||
await db.commit()
|
||||
return plugin_id
|
||||
|
||||
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||
"""Set *plugin_id* status to ``'approved'``.
|
||||
|
||||
Raises ``KeyError`` if the plugin is not found.
|
||||
"""
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||
row.status = "approved"
|
||||
row.rejection_reason = None
|
||||
await db.commit()
|
||||
|
||||
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||
|
||||
Raises ``KeyError`` if the plugin is not found.
|
||||
"""
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||
row.status = "rejected"
|
||||
row.rejection_reason = reason
|
||||
await db.commit()
|
||||
|
||||
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.install_count = row.install_count + 1
|
||||
await db.commit()
|
||||
|
||||
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.install_count = max(0, row.install_count - 1)
|
||||
await db.commit()
|
||||
|
||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||
|
||||
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||
"""Return all entries with status='pending_review'."""
|
||||
result = await db.execute(
|
||||
select(Plugin).where(Plugin.status == "pending_review")
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"manifest": _plugin_to_manifest(r),
|
||||
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = PluginRegistry()
|
||||
@@ -1,125 +0,0 @@
|
||||
"""Plugin review workflow backed by PostgreSQL.
|
||||
|
||||
Manages the approval queue for newly submitted plugins and enforces a
|
||||
security checklist before any plugin is made visible in the marketplace.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
from app.marketplace.plugin_review import review_queue
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.models import PluginReview as PluginReviewModel
|
||||
from app.schemas import PluginManifest
|
||||
|
||||
# ── Security policy ───────────────────────────────────────────────────
|
||||
|
||||
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
"read:tasks",
|
||||
"write:tasks",
|
||||
"read:projects",
|
||||
"write:projects",
|
||||
"read:notes",
|
||||
"write:notes",
|
||||
"read:timelines",
|
||||
"write:timelines",
|
||||
"read:calendar",
|
||||
"write:calendar",
|
||||
}
|
||||
)
|
||||
|
||||
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||
|
||||
|
||||
def validate_manifest(manifest: PluginManifest) -> None:
|
||||
"""Enforce the plugin security checklist.
|
||||
|
||||
Raises:
|
||||
``ValueError`` on the first violation found. Callers should catch
|
||||
this and return HTTP 422 / reject the submission.
|
||||
|
||||
Checks:
|
||||
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||
3. No manifest field contains raw binary data
|
||||
"""
|
||||
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||
raise ValueError(
|
||||
f"Invalid plugin id format: '{manifest.id}'. "
|
||||
"Only lowercase letters, digits, and hyphens are allowed."
|
||||
)
|
||||
|
||||
for perm in manifest.permissions:
|
||||
if perm not in ALLOWED_PERMISSIONS:
|
||||
raise ValueError(
|
||||
f"Unknown permission: '{perm}'. "
|
||||
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||
)
|
||||
|
||||
for field_name, value in manifest.model_dump().items():
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
raise ValueError(
|
||||
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||
)
|
||||
|
||||
|
||||
class ReviewQueue:
|
||||
"""Approval queue for pending plugin submissions.
|
||||
|
||||
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||
Review records are persisted in the ``plugin_reviews`` table.
|
||||
"""
|
||||
|
||||
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||
"""Return all plugins currently awaiting review.
|
||||
|
||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||
"""
|
||||
entries = await registry.get_pending_entries(db)
|
||||
return [
|
||||
{
|
||||
"plugin_id": e["manifest"].id,
|
||||
"manifest": e["manifest"],
|
||||
"submitted_at": e["submitted_at"],
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
|
||||
async def submit_review(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
plugin_id: str,
|
||||
reviewer_id: str,
|
||||
decision: Literal["approved", "rejected"],
|
||||
notes: str = "",
|
||||
) -> None:
|
||||
"""Record a review decision and update the plugin's status.
|
||||
|
||||
Raises:
|
||||
``KeyError`` if *plugin_id* is not found in the registry.
|
||||
"""
|
||||
if decision == "approved":
|
||||
await registry.approve_plugin(db, plugin_id)
|
||||
else:
|
||||
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||
|
||||
review = PluginReviewModel(
|
||||
plugin_id=plugin_id,
|
||||
reviewer_id=reviewer_id,
|
||||
decision=decision,
|
||||
notes=notes,
|
||||
)
|
||||
db.add(review)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
review_queue = ReviewQueue()
|
||||
@@ -1,233 +0,0 @@
|
||||
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||
|
||||
Records every plugin installation as a revenue event and facilitates
|
||||
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||
in the ``revenue_events`` table.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
from app.marketplace.revenue_share import revenue_share
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from sqlalchemy import extract, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.models import Plugin, RevenueEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Revenue split constants ───────────────────────────────────────────
|
||||
|
||||
DEVELOPER_SHARE: float = 0.70
|
||||
PLATFORM_SHARE: float = 0.30
|
||||
|
||||
|
||||
class RevenueShare:
|
||||
"""Records installation revenue events and coordinates developer payouts.
|
||||
|
||||
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||
is not configured, consistent with the rest of the billing layer.
|
||||
"""
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _stripe_configured() -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
@staticmethod
|
||||
def _stripe() -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
# ── Core operations ──────────────────────────────────────────────
|
||||
|
||||
async def record_install(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
plugin_id: str,
|
||||
user_id: str,
|
||||
amount_cents: int,
|
||||
) -> None:
|
||||
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||
|
||||
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||
the event is still recorded for analytics.
|
||||
|
||||
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||
destination charge. If Stripe is not configured or the charge fails
|
||||
the installation still succeeds (the event is recorded and the install
|
||||
count is incremented) — a warning is logged for monitoring.
|
||||
"""
|
||||
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||
stripe_transfer_id: str | None = None
|
||||
|
||||
if amount_cents > 0 and self._stripe_configured():
|
||||
# Look up the plugin's author Stripe account from the DB
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
plugin_row = result.scalar_one_or_none()
|
||||
developer_stripe_account: str | None = None
|
||||
if plugin_row and plugin_row.author_id:
|
||||
# Future: look up user.stripe_connect_account_id
|
||||
developer_stripe_account = None # no real account yet
|
||||
|
||||
if developer_stripe_account:
|
||||
try:
|
||||
s = self._stripe()
|
||||
transfer = s.Transfer.create(
|
||||
amount=developer_share_cents,
|
||||
currency="eur",
|
||||
destination=developer_stripe_account,
|
||||
description=f"Revenue share for plugin {plugin_id}",
|
||||
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||
)
|
||||
stripe_transfer_id = transfer["id"]
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Stripe Connect transfer failed for plugin %s: %s",
|
||||
plugin_id,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"No Stripe account on file for plugin %s developer; "
|
||||
"skipping transfer.",
|
||||
plugin_id,
|
||||
)
|
||||
|
||||
event = RevenueEvent(
|
||||
plugin_id=plugin_id,
|
||||
user_id=user_id,
|
||||
amount_cents=amount_cents,
|
||||
developer_share_cents=developer_share_cents,
|
||||
stripe_transfer_id=stripe_transfer_id,
|
||||
)
|
||||
db.add(event)
|
||||
await db.commit()
|
||||
|
||||
await registry.record_install(db, plugin_id)
|
||||
|
||||
async def get_earnings(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
developer_id: str,
|
||||
period: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return aggregated earnings for *developer_id*.
|
||||
|
||||
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"developer_id": str,
|
||||
"period": str | None,
|
||||
"total_installs": int,
|
||||
"total_revenue_cents": int,
|
||||
"developer_share_cents": int,
|
||||
}
|
||||
"""
|
||||
# Find plugin ids belonging to this developer (by author_name match)
|
||||
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||
plugin_result = await db.execute(plugin_q)
|
||||
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||
|
||||
if not developer_plugin_ids:
|
||||
return {
|
||||
"developer_id": developer_id,
|
||||
"period": period,
|
||||
"total_installs": 0,
|
||||
"total_revenue_cents": 0,
|
||||
"developer_share_cents": 0,
|
||||
}
|
||||
|
||||
query = select(
|
||||
func.count().label("total_installs"),
|
||||
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||
|
||||
if period:
|
||||
# Filter by YYYY-MM: extract year and month from created_at
|
||||
try:
|
||||
year, month = period.split("-")
|
||||
query = query.where(
|
||||
extract("year", RevenueEvent.created_at) == int(year),
|
||||
extract("month", RevenueEvent.created_at) == int(month),
|
||||
)
|
||||
except ValueError:
|
||||
pass # invalid period format — return all
|
||||
|
||||
result = await db.execute(query)
|
||||
row = result.one()
|
||||
|
||||
return {
|
||||
"developer_id": developer_id,
|
||||
"period": period,
|
||||
"total_installs": row.total_installs,
|
||||
"total_revenue_cents": row.total_revenue,
|
||||
"developer_share_cents": row.dev_share,
|
||||
}
|
||||
|
||||
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||
|
||||
Marks processed events with ``paid_at`` timestamp.
|
||||
Stubs gracefully when Stripe is not configured.
|
||||
"""
|
||||
try:
|
||||
year, month = period.split("-")
|
||||
year_int, month_int = int(year), int(month)
|
||||
except ValueError:
|
||||
logger.warning("Invalid period format: %s", period)
|
||||
return
|
||||
|
||||
result = await db.execute(
|
||||
select(RevenueEvent).where(
|
||||
RevenueEvent.plugin_id == plugin_id,
|
||||
RevenueEvent.paid_at.is_(None),
|
||||
extract("year", RevenueEvent.created_at) == year_int,
|
||||
extract("month", RevenueEvent.created_at) == month_int,
|
||||
)
|
||||
)
|
||||
unpaid = list(result.scalars().all())
|
||||
|
||||
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||
if total_dev_share <= 0 or not unpaid:
|
||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||
return
|
||||
|
||||
if self._stripe_configured():
|
||||
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
plugin_row = plugin_result.scalar_one_or_none()
|
||||
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||
if plugin_row and developer_stripe_account:
|
||||
try:
|
||||
s = self._stripe()
|
||||
s.Transfer.create(
|
||||
amount=total_dev_share,
|
||||
currency="eur",
|
||||
destination=developer_stripe_account,
|
||||
description=f"Payout for plugin {plugin_id} period {period}",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||
return
|
||||
|
||||
paid_ts = datetime.now(timezone.utc)
|
||||
for event in unpaid:
|
||||
event.paid_at = paid_ts
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
revenue_share = RevenueShare()
|
||||
476
app/models.py
476
app/models.py
@@ -1,476 +0,0 @@
|
||||
"""SQLAlchemy ORM models for all persistent tables.
|
||||
|
||||
Only auth, billing, storage metadata, and marketplace data live here.
|
||||
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||
|
||||
Table inventory:
|
||||
users — account credentials + tier
|
||||
refresh_tokens — hashed refresh token store
|
||||
subscriptions — Stripe subscription records
|
||||
storage_records — S3 blob metadata (no plaintext)
|
||||
backup_metadata — encrypted backup manifests
|
||||
plugins — marketplace plugin catalog
|
||||
plugin_installations — per-user install records
|
||||
plugin_reviews — admin review decisions
|
||||
revenue_events — Stripe Connect 70/30 split ledger
|
||||
memory_core — per-user persistent key/value preferences (encrypted)
|
||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||
memory_episodic — per-user session summaries (encrypted)
|
||||
memory_proactive — per-user behavioral patterns (encrypted)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
JSON,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
Uuid,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# ── Enum types ────────────────────────────────────────────────────────────
|
||||
|
||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||
# Used to encrypt/decrypt all memory rows for this user.
|
||||
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
subscription: Mapped[Subscription | None] = relationship(
|
||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="subscription")
|
||||
|
||||
|
||||
class StorageRecord(Base):
|
||||
__tablename__ = "storage_records"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class BackupMetadata(Base):
|
||||
__tablename__ = "backup_metadata"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
__tablename__ = "plugins"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||
# nullable until developer account system is built
|
||||
author_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
submitted_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
reviews: Mapped[list[PluginReview]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class PluginInstallation(Base):
|
||||
__tablename__ = "plugin_installations"
|
||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
installed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||
|
||||
|
||||
class PluginReview(Base):
|
||||
__tablename__ = "plugin_reviews"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
reviewer_id: Mapped[str | None] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
reviewed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||
|
||||
|
||||
class RevenueEvent(Base):
|
||||
__tablename__ = "revenue_events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||
|
||||
|
||||
class LocalAgentConfig(Base):
|
||||
__tablename__ = "local_agent_configs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
||||
back_populates="local_agent",
|
||||
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
||||
foreign_keys="AgentRunLog.agent_id",
|
||||
cascade="all, delete-orphan",
|
||||
overlaps="run_logs,cloud_agent",
|
||||
)
|
||||
|
||||
|
||||
class CloudAgentConfig(Base):
|
||||
__tablename__ = "cloud_agent_configs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
||||
back_populates="cloud_agent",
|
||||
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
||||
foreign_keys="AgentRunLog.agent_id",
|
||||
cascade="all, delete-orphan",
|
||||
overlaps="run_logs,local_agent",
|
||||
)
|
||||
|
||||
|
||||
class AgentRunLog(Base):
|
||||
__tablename__ = "agent_run_logs"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
|
||||
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
|
||||
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
local_agent: Mapped[LocalAgentConfig | None] = relationship(
|
||||
back_populates="run_logs",
|
||||
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
||||
foreign_keys="AgentRunLog.agent_id",
|
||||
overlaps="run_logs,cloud_agent",
|
||||
)
|
||||
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
|
||||
back_populates="run_logs",
|
||||
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
||||
foreign_keys="AgentRunLog.agent_id",
|
||||
overlaps="run_logs,local_agent",
|
||||
)
|
||||
|
||||
|
||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MemoryCore(Base):
|
||||
"""Per-user persistent key/value preferences, encrypted at rest.
|
||||
|
||||
Examples: preferred_language, timezone, work_style.
|
||||
Decrypted in-memory only using User.encryption_key.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_core"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryAssociative(Base):
|
||||
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
|
||||
|
||||
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
|
||||
Tests (SQLite): stored as JSON list.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_associative"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryEpisodic(Base):
|
||||
"""Per-user session summaries, encrypted at rest.
|
||||
|
||||
One row per session interaction; used to recall recent conversations.
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_episodic"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class MemoryProactive(Base):
|
||||
"""Per-user inferred behavioral patterns, encrypted at rest.
|
||||
|
||||
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
|
||||
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
|
||||
"""
|
||||
|
||||
__tablename__ = "memory_proactive"
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, index=True,
|
||||
)
|
||||
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
321
app/schemas.py
321
app/schemas.py
@@ -1,321 +0,0 @@
|
||||
"""Pydantic schemas — API request/response contracts.
|
||||
|
||||
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Billing ──────────────────────────────────────────────────────────
|
||||
|
||||
BillingTier = Literal["free", "pro", "power", "team"]
|
||||
|
||||
|
||||
# ── Auth ─────────────────────────────────────────────────────────────
|
||||
|
||||
class AuthTokens(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
tier: BillingTier
|
||||
|
||||
|
||||
# ── Chat ─────────────────────────────────────────────────────────────
|
||||
|
||||
class ChatContext(BaseModel):
|
||||
user_profile: dict[str, Any] = Field(default_factory=dict)
|
||||
relevant_documents: list[str] = Field(default_factory=list)
|
||||
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
context: ChatContext = Field(default_factory=ChatContext)
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
# ── Backup ───────────────────────────────────────────────────────────
|
||||
|
||||
class BackupMetadata(BaseModel):
|
||||
version: int
|
||||
timestamp: int
|
||||
checksum: str
|
||||
chunk_count: int
|
||||
|
||||
|
||||
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||
|
||||
class StorageRecord(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
table: str
|
||||
blob: bytes
|
||||
checksum: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class StorageRecordCreate(BaseModel):
|
||||
table: str
|
||||
blob: bytes
|
||||
checksum: str
|
||||
|
||||
|
||||
class StorageRecordUpdate(BaseModel):
|
||||
blob: bytes
|
||||
checksum: str
|
||||
|
||||
|
||||
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||
|
||||
class VectorItem(BaseModel):
|
||||
id: str
|
||||
blob: bytes # encrypted vector + metadata — backend never decrypts
|
||||
checksum: str
|
||||
|
||||
|
||||
class VectorUpsertRequest(BaseModel):
|
||||
vectors: list[VectorItem]
|
||||
|
||||
|
||||
class VectorSearchRequest(BaseModel):
|
||||
query_blob: bytes # encrypted query — backend never decrypts
|
||||
top_k: int = 10
|
||||
|
||||
|
||||
class VectorSearchResult(BaseModel):
|
||||
id: str
|
||||
score: float
|
||||
blob: bytes
|
||||
|
||||
|
||||
class VectorSearchResponse(BaseModel):
|
||||
results: list[VectorSearchResult]
|
||||
|
||||
|
||||
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||
|
||||
class PluginManifest(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
author: str
|
||||
permissions: list[str]
|
||||
category: str
|
||||
price_cents: int = 0
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
plugins: list[PluginManifest]
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class PluginInstallRequest(BaseModel):
|
||||
plugin_id: str
|
||||
|
||||
|
||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||
|
||||
class WsFrameType(str, Enum):
|
||||
# ── v2 frame types (kept for backward compat) ──────────────────────
|
||||
chat_request = "chat_request"
|
||||
text_chunk = "text_chunk"
|
||||
tool_call = "tool_call"
|
||||
tool_result = "tool_result"
|
||||
final = "final"
|
||||
ping = "ping"
|
||||
device_hello = "device_hello"
|
||||
# ── v3 frame types ─────────────────────────────────────────────────
|
||||
home_request = "home_request"
|
||||
floating_request = "floating_request"
|
||||
stream_start = "stream_start"
|
||||
stream_text = "stream_text"
|
||||
stream_end = "stream_end"
|
||||
floating_domain = "floating_domain"
|
||||
data_request = "data_request"
|
||||
data_response = "data_response"
|
||||
mutation = "mutation"
|
||||
# ── v4 journey frame types ────────────────────────────────────────
|
||||
journey_start = "journey_start"
|
||||
journey_message = "journey_message"
|
||||
journey_reply = "journey_reply"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
"""Server → Client: requests a CRUD/vector operation on the local DB."""
|
||||
|
||||
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
|
||||
id: str
|
||||
action: str
|
||||
table: str | None = None
|
||||
data: dict[str, Any] | None = None
|
||||
filters: dict[str, Any] | None = None
|
||||
vector: list[float] | None = None
|
||||
limit: int | None = None
|
||||
|
||||
|
||||
class WsToolResult(BaseModel):
|
||||
"""Client → Server: result of a CRUD/vector operation."""
|
||||
|
||||
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
|
||||
id: str
|
||||
row: dict[str, Any] | None = None
|
||||
rows: list[dict[str, Any]] | None = None
|
||||
results: list[dict[str, Any]] | None = None
|
||||
deleted: bool | None = None
|
||||
ok: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class WsTextChunk(BaseModel):
|
||||
"""Server → Client: incremental LLM response text."""
|
||||
|
||||
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
|
||||
text: str
|
||||
|
||||
|
||||
class WsFinal(BaseModel):
|
||||
"""Server → Client: signals end of response with the complete text."""
|
||||
|
||||
type: Literal[WsFrameType.final] = WsFrameType.final
|
||||
response: str
|
||||
|
||||
|
||||
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
|
||||
|
||||
class WsDeviceHello(BaseModel):
|
||||
"""Client → Server: device identification on WS connect."""
|
||||
|
||||
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
|
||||
device_id: str
|
||||
agent_ids: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||
|
||||
class WsFloatingScope(BaseModel):
|
||||
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||
|
||||
type: Literal["task", "project", "note", "timeline"]
|
||||
id: str | None = None
|
||||
|
||||
|
||||
class WsHomeRequest(BaseModel):
|
||||
"""Client → Server: Home chat message."""
|
||||
|
||||
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||
message: str
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WsFloatingRequest(BaseModel):
|
||||
"""Client → Server: Floating chat message scoped to an entity."""
|
||||
|
||||
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||
message: str
|
||||
scope: WsFloatingScope
|
||||
|
||||
|
||||
class WsStreamStart(BaseModel):
|
||||
"""Server → Client: signals start of a streaming response."""
|
||||
|
||||
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
|
||||
request_id: str
|
||||
|
||||
|
||||
class WsStreamText(BaseModel):
|
||||
"""Server → Client: streamed text token."""
|
||||
|
||||
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
|
||||
request_id: str
|
||||
chunk: str
|
||||
|
||||
|
||||
class WsStreamEnd(BaseModel):
|
||||
"""Server → Client: signals end of a streaming response."""
|
||||
|
||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||
request_id: str
|
||||
|
||||
|
||||
class WsDomain(BaseModel):
|
||||
"""Structured floating domain payload for UI routing decisions."""
|
||||
|
||||
type: Literal["task", "timeline", "project", "node"]
|
||||
id: str | None = None
|
||||
section: Literal["task", "timeline", "note"] | None = None
|
||||
|
||||
|
||||
class WsFloatingDomain(BaseModel):
|
||||
"""Server → Client: domain determined for a floating request."""
|
||||
|
||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||
request_id: str
|
||||
domain: WsDomain
|
||||
|
||||
|
||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||
|
||||
class AgentCatalogItem(BaseModel):
|
||||
type: str
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
class AgentCreationCheckRequest(BaseModel):
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
|
||||
|
||||
class AgentCreationCheckResponse(BaseModel):
|
||||
allowed: bool
|
||||
tier: BillingTier
|
||||
active_agents: int
|
||||
limit: int
|
||||
|
||||
|
||||
class AgentTriggerRequest(BaseModel):
|
||||
directory: str = Field(min_length=1)
|
||||
device_id: str = Field(default="")
|
||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||
what_to_extract: list[str] = Field(min_length=1)
|
||||
actions_by_type: dict[str, list[str]] | None = None
|
||||
batch_interval: str = Field(min_length=1)
|
||||
custom_agent_prompt: str = Field(min_length=1)
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
|
||||
|
||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||
|
||||
class AgentRunLogResponse(BaseModel):
|
||||
id: str
|
||||
agent_id: str
|
||||
agent_type: Literal["local", "cloud"]
|
||||
status: Literal["running", "success", "error", "partial"]
|
||||
items_processed: int
|
||||
items_created: int
|
||||
errors: list[str]
|
||||
started_at: int
|
||||
completed_at: int | None
|
||||
|
||||
|
||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||
@@ -1,106 +0,0 @@
|
||||
"""S3-backed store for E2E-encrypted blobs.
|
||||
|
||||
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class BlobStore:
|
||||
"""Thin wrapper around boto3 S3.
|
||||
|
||||
All blobs must be E2E encrypted by the client before upload.
|
||||
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
||||
but cannot decrypt the inner client-side payload.
|
||||
"""
|
||||
|
||||
def _client(self) -> Any:
|
||||
kwargs: dict[str, Any] = {
|
||||
"region_name": settings.S3_REGION,
|
||||
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
||||
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
||||
}
|
||||
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
||||
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
||||
return boto3.client("s3", **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||
return f"{user_id}/{table}/{record_id}"
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
user_id: str,
|
||||
table: str,
|
||||
record_id: str,
|
||||
blob: bytes,
|
||||
checksum: str,
|
||||
) -> str:
|
||||
"""Store *blob* in S3 and return the S3 key.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the blob (used as key prefix).
|
||||
table: Logical table name (e.g. ``"tasks"``).
|
||||
record_id: Record UUID.
|
||||
blob: Raw bytes (pre-encrypted by client).
|
||||
checksum: SHA-256 hex digest supplied by the client; stored as
|
||||
object metadata for download-time verification.
|
||||
|
||||
Returns:
|
||||
The S3 key under which the blob was stored.
|
||||
"""
|
||||
key = self._key(user_id, table, record_id)
|
||||
self._client().put_object(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Key=key,
|
||||
Body=blob,
|
||||
ServerSideEncryption="AES256", # SSE-S3 at rest
|
||||
Metadata={"checksum": checksum},
|
||||
)
|
||||
return key
|
||||
|
||||
async def download(self, user_id: str, s3_key: str) -> bytes:
|
||||
"""Retrieve the blob stored at *s3_key*.
|
||||
|
||||
*user_id* is retained in the signature so higher-level code can
|
||||
enforce ownership without re-parsing the key.
|
||||
|
||||
Raises:
|
||||
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
||||
object does not exist.
|
||||
"""
|
||||
response = self._client().get_object(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Key=s3_key,
|
||||
)
|
||||
return response["Body"].read()
|
||||
|
||||
async def delete(self, user_id: str, s3_key: str) -> None:
|
||||
"""Delete the object at *s3_key*.
|
||||
|
||||
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
||||
not exist.
|
||||
"""
|
||||
self._client().delete_object(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Key=s3_key,
|
||||
)
|
||||
|
||||
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
||||
"""Return all S3 keys for a given user + table combination.
|
||||
|
||||
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
||||
"""
|
||||
prefix = f"{user_id}/{table}/"
|
||||
response = self._client().list_objects_v2(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Prefix=prefix,
|
||||
)
|
||||
return [obj["Key"] for obj in response.get("Contents", [])]
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
||||
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
||||
|
||||
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
||||
timing-based side-channel attacks.
|
||||
"""
|
||||
computed = hashlib.sha256(blob).hexdigest()
|
||||
return hmac.compare_digest(computed, checksum)
|
||||
|
||||
|
||||
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
||||
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
||||
|
||||
Call this before storing or forwarding any client-provided blob.
|
||||
The backend never holds decryption keys — this check only verifies
|
||||
that the opaque bytes arrived intact.
|
||||
"""
|
||||
if not verify_checksum(blob, checksum):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Checksum mismatch: blob integrity check failed",
|
||||
)
|
||||
@@ -1,205 +0,0 @@
|
||||
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||
|
||||
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||
alongside a deterministic 32-dim float representation derived from the blob's
|
||||
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||
is a known trade-off documented in the backend plan.
|
||||
|
||||
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||
``user_id`` payload field on a shared collection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from pinecone import Pinecone
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas import VectorItem, VectorSearchResult
|
||||
|
||||
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||
|
||||
|
||||
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||
|
||||
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||
semantic meaning on encrypted data.
|
||||
"""
|
||||
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""Thin wrapper around Pinecone or Qdrant.
|
||||
|
||||
The backend to use is selected at runtime:
|
||||
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
||||
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
||||
"""
|
||||
|
||||
def _use_pinecone(self) -> bool:
|
||||
return bool(settings.PINECONE_API_KEY)
|
||||
|
||||
# ── Pinecone helpers ──────────────────────────────────────────────
|
||||
|
||||
def _pinecone_index(self) -> Any:
|
||||
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
||||
return pc.Index(settings.PINECONE_INDEX)
|
||||
|
||||
# ── Qdrant helpers ────────────────────────────────────────────────
|
||||
|
||||
def _qdrant_client(self) -> Any:
|
||||
return QdrantClient(
|
||||
url=settings.QDRANT_URL,
|
||||
api_key=settings.QDRANT_API_KEY or None,
|
||||
)
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────
|
||||
|
||||
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||
"""Store encrypted vectors in the backend.
|
||||
|
||||
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
||||
so it can be returned verbatim during search.
|
||||
|
||||
Args:
|
||||
user_id: Used as Pinecone namespace or Qdrant payload field.
|
||||
vectors: List of encrypted vector items from the client.
|
||||
"""
|
||||
if self._use_pinecone():
|
||||
await self._pinecone_upsert(user_id, vectors)
|
||||
else:
|
||||
await self._qdrant_upsert(user_id, vectors)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
user_id: str,
|
||||
query_blob: bytes,
|
||||
top_k: int,
|
||||
) -> list[VectorSearchResult]:
|
||||
"""Query the vector store and return encrypted result blobs.
|
||||
|
||||
The query vector is derived from *query_blob* using the same
|
||||
deterministic mapping as upsert.
|
||||
|
||||
Args:
|
||||
user_id: Scopes the search to this user's namespace.
|
||||
query_blob: Encrypted query from the client.
|
||||
top_k: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
||||
"""
|
||||
if self._use_pinecone():
|
||||
return await self._pinecone_search(user_id, query_blob, top_k)
|
||||
return await self._qdrant_search(user_id, query_blob, top_k)
|
||||
|
||||
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||
"""Remove vectors by ID, scoped to *user_id*.
|
||||
|
||||
Args:
|
||||
user_id: Namespace / payload filter to prevent cross-user deletion.
|
||||
vector_ids: List of vector IDs to remove.
|
||||
"""
|
||||
if self._use_pinecone():
|
||||
await self._pinecone_delete(user_id, vector_ids)
|
||||
else:
|
||||
await self._qdrant_delete(user_id, vector_ids)
|
||||
|
||||
# ── Pinecone implementation ───────────────────────────────────────
|
||||
|
||||
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||
index = self._pinecone_index()
|
||||
records = [
|
||||
{
|
||||
"id": v.id,
|
||||
"values": _blob_to_vector(v.blob),
|
||||
"metadata": {
|
||||
"blob": base64.b64encode(v.blob).decode(),
|
||||
"checksum": v.checksum,
|
||||
"user_id": user_id,
|
||||
},
|
||||
}
|
||||
for v in vectors
|
||||
]
|
||||
index.upsert(vectors=records, namespace=user_id)
|
||||
|
||||
async def _pinecone_search(
|
||||
self, user_id: str, query_blob: bytes, top_k: int
|
||||
) -> list[VectorSearchResult]:
|
||||
index = self._pinecone_index()
|
||||
query_vector = _blob_to_vector(query_blob)
|
||||
response = index.query(
|
||||
vector=query_vector,
|
||||
top_k=top_k,
|
||||
namespace=user_id,
|
||||
include_metadata=True,
|
||||
)
|
||||
results: list[VectorSearchResult] = []
|
||||
for match in response.get("matches", []):
|
||||
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
||||
results.append(
|
||||
VectorSearchResult(
|
||||
id=match["id"],
|
||||
score=match["score"],
|
||||
blob=blob_bytes,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||
index = self._pinecone_index()
|
||||
index.delete(ids=vector_ids, namespace=user_id)
|
||||
|
||||
# ── Qdrant implementation ─────────────────────────────────────────
|
||||
|
||||
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||
client = self._qdrant_client()
|
||||
points = [
|
||||
PointStruct(
|
||||
id=v.id,
|
||||
vector=_blob_to_vector(v.blob),
|
||||
payload={
|
||||
"blob": base64.b64encode(v.blob).decode(),
|
||||
"checksum": v.checksum,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
for v in vectors
|
||||
]
|
||||
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
||||
|
||||
async def _qdrant_search(
|
||||
self, user_id: str, query_blob: bytes, top_k: int
|
||||
) -> list[VectorSearchResult]:
|
||||
client = self._qdrant_client()
|
||||
query_vector = _blob_to_vector(query_blob)
|
||||
hits = client.search(
|
||||
collection_name=_QDRANT_COLLECTION,
|
||||
query_vector=query_vector,
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
||||
),
|
||||
limit=top_k,
|
||||
)
|
||||
return [
|
||||
VectorSearchResult(
|
||||
id=str(hit.id),
|
||||
score=hit.score,
|
||||
blob=base64.b64decode(hit.payload["blob"]),
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
|
||||
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||
client = self._qdrant_client()
|
||||
client.delete(
|
||||
collection_name=_QDRANT_COLLECTION,
|
||||
points_selector=PointIdsList(points=vector_ids),
|
||||
)
|
||||
@@ -1,37 +0,0 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
gunicorn>=22.0.0
|
||||
langchain>=0.3.0
|
||||
langchain-openai>=0.3.0
|
||||
langchain-litellm>=0.1.0
|
||||
litellm>=1.50.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
stripe>=11.0.0
|
||||
boto3>=1.35.0
|
||||
slowapi>=0.1.9
|
||||
sqlalchemy>=2.0.0
|
||||
asyncpg>=0.30.0
|
||||
alembic>=1.14.0
|
||||
bcrypt>=4.2.0
|
||||
python-dotenv>=1.0.0
|
||||
httpx>=0.28.0
|
||||
websockets>=14.0
|
||||
psycopg2-binary>=2.9.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.24.0
|
||||
aiosqlite>=0.20.0
|
||||
moto[s3]>=5.0.0
|
||||
pinecone>=5.0.0
|
||||
qdrant-client>=1.7.0
|
||||
croniter>=3.0.0
|
||||
google-api-python-client>=2.130.0
|
||||
google-auth>=2.29.0
|
||||
google-auth-oauthlib>=1.2.0
|
||||
google-auth-httplib2>=0.2.0
|
||||
msal>=1.28.0
|
||||
cryptography>=42.0.0
|
||||
redis>=5.0.0
|
||||
langfuse>=3.0.0
|
||||
ruff>=0.8.0
|
||||
Reference in New Issue
Block a user