refactor: remove monolith app/, Dockerfile, requirements.txt

All business logic has been extracted into microservices:
  - services/auth/       (Step 1)
  - services/ws-gateway/ (Step 2)
  - services/chat/       (Step 2)
  - services/batch-agent/ (Step 3)
  - services/billing/    (Step 4)

Shared code lives in shared/.
Migrations remain in alembic/.
Tests in tests/ will need updating to target individual services.
This commit is contained in:
Roberto Musso
2026-04-06 23:44:12 +02:00
parent 7f6ea29525
commit 2b7d302ef2
55 changed files with 0 additions and 9159 deletions

View File

@@ -1,39 +0,0 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY requirements.txt .
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
# Non-root user
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
# Copy installed packages from builder
COPY --from=builder /install /usr/local
# Copy application source
COPY app/ app/
# Copy Alembic migration files
COPY alembic/ alembic/
COPY alembic.ini .
# Ensure appuser owns the working directory
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "4", \
"--timeout", "120"]

View File

View File

@@ -1,5 +0,0 @@
"""Expose tool modules used by deep orchestrator-worker graphs."""
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -1,85 +0,0 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
These tools delegate to the Electron client via ``execute_on_client()`` using
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
handles actual disk I/O and responds with ``tool_result`` frames.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -1,139 +0,0 @@
"""Note agent — Markdown note management (list, get, create, update, delete)."""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.core.llm import embed
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
async def create_note(
title: str,
content: str,
project_id: str = "",
) -> str:
"""Create a new note.
title: note heading (required)
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
)
row = result["row"]
# Index the note content in the vector store.
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool
async def update_note(
note_id: str,
title: str = "",
content: str = "",
) -> str:
"""Update an existing note. Only pass fields that should change.
note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first.
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
# Re-index if content changed.
if content:
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -1,143 +0,0 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
client_id: str = "",
include_archived: int = 0,
) -> str:
"""List projects, optionally filtered by client_id.
include_archived: 1 to include archived projects, 0 for active only (default).
"""
result = await execute_on_client(
action="select",
table="projects",
filters={
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
)
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status.
Use only when the user wants a complete cross-client overview.
"""
result = await execute_on_client(action="select", table="projects")
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
row = result.get("row")
if not row:
return f"Project {project_id} not found."
return (
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
f"clientId: {row.get('clientId', 'none')})"
)
@tool
async def create_project(
name: str,
client_id: str = "",
) -> str:
"""Create a new project.
name: human-readable project name (required)
client_id: optional UUID of the owning client
"""
result = await execute_on_client(
action="insert",
table="projects",
data={"name": name, "clientId": client_id or None},
)
row = result["row"]
return f"Project created: '{row['name']}' (id: {row['id']})"
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change.
project_id: UUID of the project (required)
status: active | archived
ai_summary: AI-generated summary text (populate only when explicitly requested)
"""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
result = await execute_on_client(
action="update",
table="projects",
data={"id": project_id, "updates": updates},
)
row = result["row"]
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project and orphan its tasks.
IMPORTANT: prefer update_project(status='archived') unless the user
has explicitly confirmed they want permanent deletion.
"""
await execute_on_client(action="delete", table="projects", data={"id": project_id})
return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -1,238 +0,0 @@
"""Task agent — full CRUD for tasks and task comments."""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
) -> str:
"""Create a new task.
title: task title (required)
description: optional details
status: todo | in_progress | done (default: todo)
priority: high | medium | low (default: medium)
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
)
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1 # last ms of today
result = await execute_on_client(
action="select",
table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
# ── Task comment tools ────────────────────────────────────────────────
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task.
task_id: UUID of the task to comment on
author: name or ID of the comment author
content: comment text
"""
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -1,114 +0,0 @@
"""Timeline agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No timelines found."
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
@tool
async def create_timeline(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
@tool
async def update_timeline(
timeline_id: str,
title: str = "",
date: int = -1,
) -> str:
"""Update a timeline. Only pass fields that should change.
timeline_id: UUID of the timeline (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

View File

@@ -1,14 +0,0 @@
"""Shared FastAPI dependencies.
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
(the canonical location per Step 9). This module re-exports them so that all
existing route imports (``from app.api.deps import get_current_user``) continue
to work without modification.
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
instead of reading it from the JWT payload.
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
__all__ = ["get_current_user", "oauth2_scheme"]

View File

@@ -1,19 +0,0 @@
"""API middleware package.
Exports the three middleware components introduced in Step 9:
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
- Sanitizer: ``SanitizerMiddleware``
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
from app.api.middleware.sanitizer import SanitizerMiddleware
__all__ = [
"get_current_user",
"oauth2_scheme",
"TierRateLimitMiddleware",
"limiter",
"SanitizerMiddleware",
]

View File

@@ -1,80 +0,0 @@
"""Auth middleware — JWT validation dependency.
``get_current_user`` is the FastAPI dependency used by all protected routes.
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
from the ``subscriptions`` table so that tier changes take effect immediately
without requiring token re-issue.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.db import get_session
from app.schemas import UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry only. The tier is fetched live
from the ``subscriptions`` table so that upgrades/downgrades take effect
immediately. Falls back to ``'free'`` when no subscription row exists.
Raises HTTP 401 on any invalid or expired token.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup — subscription row is the authoritative source.
# In dev, fall back to 'power' (unlimited) so quota limits don't
# block local development when no Stripe subscription exists.
from app.models import Subscription, User # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname from user row.
user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id)
)
user_row = user_result.one_or_none()
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
tier=tier,
) # type: ignore[arg-type]

View File

@@ -1,129 +0,0 @@
"""Tier-aware rate limiting middleware.
Uses a per-user sliding-window counter (in-process, no Redis required).
The ``slowapi`` Limiter is also exported for optional route-level decoration.
Limits (requests per minute):
- free: 20
- pro: 60
- power: 120
- team: 200
Exempt paths bypass the limiter entirely:
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
from __future__ import annotations
import json
import time
from collections import defaultdict
from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.config.settings import settings
_TIER_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/billing/webhook",
"/api/v1/health",
}
)
def _get_user_id_from_jwt(request: Request) -> str:
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return get_remote_address(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
return payload.get("sub") or get_remote_address(request)
except JWTError:
return get_remote_address(request)
# Exported Limiter instance — available for optional route-level decoration.
limiter = Limiter(key_func=_get_user_id_from_jwt)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""Sliding-window rate limiter applied globally across all non-exempt routes.
Each authenticated user gets their own 60-second window sized by tier.
Unauthenticated requests pass through (the auth dependency will reject them
with 401 before the route handler runs).
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
# user_id → list of request timestamps (float, seconds since epoch)
self._window: dict[str, list[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
if request.url.path in _EXEMPT_PATHS:
return await call_next(request)
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return await call_next(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str = payload.get("sub") or get_remote_address(request)
tier: str = payload.get("tier", "free")
except JWTError:
return await call_next(request)
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
now = time.monotonic()
window_start = now - 60.0
# Slide the window: discard timestamps older than 60 seconds.
timestamps = [t for t in self._window[user_id] if t > window_start]
if len(timestamps) >= limit:
retry_after = max(1, int(60 - (now - min(timestamps))))
return Response(
content=json.dumps(
{
"detail": (
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
f"Retry in {retry_after}s."
)
}
),
status_code=429,
headers={
"Retry-After": str(retry_after),
"Content-Type": "application/json",
},
)
timestamps.append(now)
self._window[user_id] = timestamps
return await call_next(request)

View File

@@ -1,139 +0,0 @@
"""Response sanitizer middleware.
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:
- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
from __future__ import annotations
import json
import logging
import re
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Detection patterns — order matters: fingerprints checked first (exact),
# then compiled regexes.
# ---------------------------------------------------------------------------
_FINGERPRINTS: tuple[str, ...] = (
"You are an intent classifier",
"Respond with just the agent name",
"Summarize these agent results",
"Available agents:",
"route to:",
)
_PATTERNS: tuple[re.Pattern[str], ...] = (
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
re.compile(r"Available agents\s*:", re.IGNORECASE),
re.compile(r"\bintent classifier\b", re.IGNORECASE),
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
re.compile(r"route\s+to\s*:", re.IGNORECASE),
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
)
def _sanitize_text(text: str) -> tuple[str, bool]:
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
Returns ``(cleaned_text, was_changed)``.
"""
# Fingerprint check — if any exact phrase is present, redact the whole string.
for fp in _FINGERPRINTS:
if fp in text:
return "[REDACTED]", True
changed = False
for pattern in _PATTERNS:
new_text, n = pattern.subn("[REDACTED]", text)
if n:
text = new_text
changed = True
return text, changed
class SanitizerMiddleware(BaseHTTPMiddleware):
"""Strip prompt IP from /api/v1/chat JSON responses."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
response: Response = await call_next(request)
# Only process chat endpoint responses.
if not request.url.path.startswith("/api/v1/chat"):
return response
# Read body — collect streaming chunks.
body_bytes = b""
async for chunk in response.body_iterator:
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
try:
body = json.loads(body_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
if not isinstance(body, dict):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Walk top-level string fields and sanitise.
sanitised_fields: list[str] = []
for key, value in body.items():
if isinstance(value, str):
cleaned, changed = _sanitize_text(value)
if changed:
body[key] = cleaned
sanitised_fields.append(key)
if sanitised_fields:
logger.warning(
"Sanitizer redacted prompt fragments",
extra={
"path": request.url.path,
"fields": sanitised_fields,
},
)
new_body = json.dumps(body).encode("utf-8")
headers = dict(response.headers)
headers["content-length"] = str(len(new_body))
return Response(
content=new_body,
status_code=response.status_code,
headers=headers,
media_type="application/json",
)

View File

@@ -1,406 +0,0 @@
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message``
frames to the functions exported here.
Journey flow:
1. FE sends ``journey_start`` frame with basic agent config (directory,
data_types, schedule).
2. Server creates an in-memory session, sets up a WS executor so the
setup LLM can use file-system tools, does a first directory scrape,
and sends back a ``journey_reply`` with the first question.
3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
6. Server parses the block, sends ``journey_reply`` with ``done=True``
and the template. FE stores it locally.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.core.llm import get_llm
logger = logging.getLogger(__name__)
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced prompt_template.
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
# Minimum turns before we consider nudging the LLM to wrap up.
_MIN_TURNS_BEFORE_NUDGE: int = 3
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
_MAX_TURNS: int = 15
# Max tool-calling steps per LLM invocation.
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
# session_id → session
_sessions: dict[str, JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
return None
if s.user_id != user_id:
return None
return s
# ── System prompt builder ─────────────────────────────────────────────────
_SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their
local directory and produce a detailed prompt_template that a separate AI will use
as its instruction set.
The extraction agent already has this base behaviour built in:
- Reads each file using file-system tools.
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
- Sets isAiSuggested=1 on every new record.
- Only extracts data explicitly present in the files — it never invents information.
The user's custom prompt is appended AFTER this base behaviour, so focus on
what to look for and how to map it — not on the general extraction mechanics.
You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
The user's configured directory is: {directory}
Target data types: {data_types}
IMPORTANT — project assignment is handled automatically by the main agent runner
before the custom prompt is ever used. You MUST NOT ask the user about projects,
projectId, or how to link records to projects. Never include projectId logic or
project creation instructions in the generated prompt_template.
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content (confirmed by your exploration).
2. How fields should be mapped (e.g. filename → task title).
3. Priority or status rules (e.g. "urgent" keyword → high priority).
4. Any special handling, date extraction, or exclusions.
Once you reach 90% confidence, output the final prompt_template between these exact
markers on their own lines:
{template_start}
<the complete extraction prompt here>
{template_end}
The prompt_template must be a self-contained instruction for an AI that reads files
and must perform CRUD operations using tools to create records. It should specify:
- What entity types to create (tasks, notes, timelines) — never projects.
- How to map file content to record fields (camelCase: title, status, priority,
dueDate, content, etc.) — never include projectId.
- That isAiSuggested must be set to 1 on every new record.
- Concrete examples of mappings based on what you discovered in the directory.
{existing_section}\
Keep asking clarifying questions until you are at least 90% confident you have
enough information to generate an accurate prompt_template. Once you reach that
confidence level, stop asking and produce the final template immediately.
Begin by exploring the directory, then ask your first question.\
"""
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_template: str | None = None,
) -> str:
existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n"
if existing_template
else ""
)
return _SYSTEM_PROMPT_TEMPLATE.format(
directory=directory,
data_types=", ".join(data_types),
template_start=_TEMPLATE_START,
template_end=_TEMPLATE_END,
existing_section=existing_section,
)
# ── Template extraction ───────────────────────────────────────────────────
def _extract_template(text: str) -> str | None:
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
return None
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
end_idx = text.index(_TEMPLATE_END)
return text[start_idx:end_idx].strip() or None
# ── LLM call with tool support ───────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
messages.append(HumanMessage(content=turn["content"]))
else:
messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(_MAX_TOOL_STEPS):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_setup: journey tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_setup: journey tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Journey handlers (called from device_ws.py) ──────────────────────────
async def handle_journey_start(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_start`` WS frame.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template")
# Use the session_id provided by the FE so the reply matches the
# listener key; fall back to a generated one if absent.
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template)
session = JourneySession(
session_id=session_id,
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
)
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
# ws_context executor (already set by the WS handler before calling us).
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
# require at least one user/input message to be present.
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info(
"agent_setup: journey session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
# Check if the LLM produced the template on the first turn (unlikely but possible).
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_message`` WS frame.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
}
# Append user turn.
session.history.append({"role": "user", "content": message})
# Call the LLM with tools.
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final template.
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
# If the LLM didn't produce a template, nudge it once it has asked enough
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}

View File

@@ -1,222 +0,0 @@
"""Agent routes.
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/trigger — trigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import uuid
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
AgentTriggerRequest,
UserProfile,
)
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
return AgentRunLogResponse(
id=log.id,
agent_id=log.agent_id,
agent_type=log.agent_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[AgentCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[AgentCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
AgentCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
AgentCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
AgentCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
AgentCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)

View File

@@ -1,235 +0,0 @@
"""Auth routes: register, login, refresh, me.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.config.settings import settings
from app.db import get_session
from app.models import RefreshToken, User
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (signed JWT, expires_at_ms)."""
now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": exp,
"iat": now,
}
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return token, exp * 1000 # ms for client
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
name: str | None = None
surname: str | None = None
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user = User(
id=str(uuid.uuid4()),
email=body.email,
name=body.name,
surname=body.surname,
password_hash=_hash_password(body.password),
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(user)
await db.flush() # get user.id without committing
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
# Rotate: delete old token, issue new one.
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user
@router.put("/me", response_model=UserProfile)
async def update_profile(
body: _UpdateProfileRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's name and surname."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if body.name is not None:
user.name = body.name
if body.surname is not None:
user.surname = body.surname
await db.commit()
await db.refresh(user)
return UserProfile(
id=user.id,
email=user.email,
name=user.name,
surname=user.surname,
tier=current_user.tier,
)

View File

@@ -1,171 +0,0 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -1,85 +0,0 @@
"""Billing routes: Stripe checkout, webhook, subscription management.
Business logic lives in ``app.billing.stripe_service.StripeService``.
The route layer handles HTTP concerns (request parsing, response shaping)
and delegates everything else to the service singleton.
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.stripe_service import stripe_service
from app.db import get_session
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
# ── Request bodies ─────────────────────────────────────────────────────
class _CheckoutRequest(BaseModel):
tier: BillingTier
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/checkout", response_model=dict)
async def create_checkout(
body: _CheckoutRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, str]:
"""Create a Stripe checkout session for a tier upgrade.
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
"""
url = stripe_service.create_checkout_session(current_user.id, body.tier)
return {"checkout_url": url}
@router.post("/webhook", response_model=dict)
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
No JWT auth — authenticated via Stripe signature verification instead.
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
@router.get("/subscription", response_model=dict)
async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = await stripe_service.get_subscription(current_user.id, db)
if sub is None:
return {
"tier": current_user.tier,
"status": "free",
"stripe_subscription_id": None,
"current_period_end": None,
}
return sub
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Cancel the active subscription."""
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}

View File

@@ -1,29 +0,0 @@
"""Chat routes: POST /chat (REST fallback).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
"""
from __future__ import annotations
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from app.api.deps import get_current_user
from app.core.deep_agent import run_home
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("")
async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""REST fallback for home chat when websocket streaming is unavailable."""
response = await run_home(
user_id=current_user.id,
message=body.message,
context=body.context.model_dump(),
)
return JSONResponse(content={"response": response})

View File

@@ -1,417 +0,0 @@
"""Device WebSocket endpoint.
Persistent connection from Electron devices to the backend.
WS /api/v1/ws/device?token=<jwt>
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
available during the WebSocket handshake).
Protocol:
1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat.
Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``journey_start`` → starts a guided setup journey session.
- ``journey_message`` → continues a journey conversation.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
On disconnect:
- Unregisters from DeviceConnectionManager.
- Marks all in-progress AgentRunLog rows for this user as ``error``
with message "device disconnected".
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
from app.schemas import WsFrameType
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections.
Authentication is via ``?token=<jwt>`` query parameter.
"""
# ── 1. Authenticate before accepting ─────────────────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # Policy Violation
return
await websocket.accept()
# ── 2. Await device_hello frame ───────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket)
logger.info(
"device_ws: connected user=%s device=%s agents=%s",
user_id,
device_id,
agent_ids,
)
# Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
await asyncio.gather(
_message_loop(websocket, user_id),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
finally:
device_manager.unregister(user_id)
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
await _mark_runs_disconnected(user_id)
# ── Message dispatch loop ─────────────────────────────────────────────
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and dispatch to the appropriate handler."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("device_ws: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
device_manager.resolve_pending_call(user_id, call_id, frame)
else:
logger.warning(
"device_ws: tool_result missing id from user=%s", user_id
)
elif frame_type == WsFrameType.home_request:
asyncio.create_task(
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_message:
asyncio.create_task(
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
else:
logger.debug(
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
)
# ── v3 Chat Handlers ──────────────────────────────────────────────────
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
return _executor
async def _handle_home_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: home_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
async def _handle_floating_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
async def _handle_journey_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_start frame — explores directory and sends first question."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_start(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
logger.error(
"device_ws: journey_start failed user=%s: %s", user_id, exc
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": frame.get("session_id", ""),
"message": f"Failed to start journey: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
async def _handle_journey_message(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_message frame — continues the journey conversation."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_message(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
session_id = frame.get("session_id", "")
logger.error(
"device_ws: journey_message failed user=%s session=%s: %s",
user_id, session_id, exc,
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey error: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send a ping frame every 30 s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))
# ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
try:
async with async_session() as db:
await db.execute(
update(AgentRunLog)
.where(
AgentRunLog.user_id == user_id,
AgentRunLog.status == "running",
)
.values(
status="error",
errors=["device disconnected"],
)
)
await db.commit()
except Exception as exc:
logger.error(
"device_ws: failed to mark runs as disconnected for user=%s: %s",
user_id,
exc,
)

View File

@@ -1,148 +0,0 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

View File

@@ -1,195 +0,0 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

View File

@@ -1,79 +0,0 @@
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.llm import embed
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}
@router.post("/vectors/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -1,4 +0,0 @@
from app.billing.stripe_service import stripe_service
from app.billing.tier_manager import tier_manager
__all__ = ["stripe_service", "tier_manager"]

View File

@@ -1,256 +0,0 @@
"""Stripe service: checkout sessions, webhook handling, subscription management.
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
configured, enabling local development without live credentials.
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
# Stripe price IDs per tier — replace with real IDs in production .env
TIER_PRICE_IDS: dict[str, str] = {
"pro": "price_pro_monthly",
"power": "price_power_monthly",
"team": "price_team_monthly",
}
class StripeService:
"""Wraps all Stripe interactions and owns subscription persistence."""
# ── Internal helpers ────────────────────────────────────────────────
def _configured(self) -> bool:
return bool(settings.STRIPE_SECRET_KEY)
def _client(self) -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Public API ──────────────────────────────────────────────────────
def create_checkout_session(
self,
user_id: str,
tier: str,
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel",
) -> str:
"""Create a Stripe checkout session and return the URL.
Returns a stub URL when Stripe is not configured.
Raises ``HTTP 400`` for the free tier or an unknown tier.
"""
if tier == "free":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create a checkout session for the free tier",
)
price_id = TIER_PRICE_IDS.get(tier)
if not price_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown tier: {tier}",
)
if not self._configured():
return "https://stripe.com/stub-checkout"
s = self._client()
session = s.checkout.Session.create(
payment_method_types=["card"],
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
success_url=success_url,
cancel_url=cancel_url,
metadata={"user_id": user_id, "tier": tier},
)
return session.url
async def handle_webhook(
self,
payload: bytes,
sig_header: str,
db: AsyncSession,
) -> None:
"""Process a Stripe webhook event.
Verifies the signature, then dispatches on event type.
Raises ``HTTP 400`` on signature mismatch.
No-ops when Stripe is not configured.
"""
if not self._configured():
return
try:
s = self._client()
event = s.Webhook.construct_event(
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
)
except stripe_lib.error.SignatureVerificationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid Stripe signature",
)
event_type: str = event["type"]
data: dict[str, Any] = event["data"]["object"]
if event_type == "checkout.session.completed":
user_id = data.get("metadata", {}).get("user_id")
tier = data.get("metadata", {}).get("tier", "free")
sub_id = data.get("subscription")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if user_id:
await self._upsert_subscription(
db, user_id, sub_id, tier, "active", period_end
)
elif event_type == "customer.subscription.updated":
sub_id = data.get("id")
new_status = data.get("status", "active")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status=new_status, current_period_end=period_end
)
elif event_type == "customer.subscription.deleted":
sub_id = data.get("id")
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, tier="free", status="canceled"
)
elif event_type == "invoice.payment_failed":
sub_id = data.get("subscription")
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status="past_due"
)
await db.commit()
async def get_subscription(
self, user_id: str, db: AsyncSession
) -> dict[str, Any] | None:
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
return None
return {
"tier": sub.tier,
"stripe_subscription_id": sub.stripe_subscription_id,
"status": sub.status,
"current_period_end": (
int(sub.current_period_end.timestamp() * 1000)
if sub.current_period_end
else None
),
}
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
"""Cancel the user's Stripe subscription and downgrade them to free.
Raises ``HTTP 404`` when no active subscription exists.
"""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None or not sub.stripe_subscription_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active subscription found",
)
if self._configured():
s = self._client()
s.Subscription.cancel(sub.stripe_subscription_id)
sub.tier = "free"
sub.status = "canceled"
await db.commit()
# ── Private DB helpers ───────────────────────────────────────────────
async def _upsert_subscription(
self,
db: AsyncSession,
user_id: str,
stripe_subscription_id: str | None,
tier: str,
sub_status: str,
current_period_end: datetime | None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
sub = Subscription(user_id=user_id)
db.add(sub)
sub.stripe_subscription_id = stripe_subscription_id
sub.tier = tier
sub.status = sub_status
sub.current_period_end = current_period_end
async def _update_subscription_by_stripe_id(
self,
db: AsyncSession,
stripe_subscription_id: str,
*,
tier: str | None = None,
status: str | None = None,
current_period_end: datetime | None = None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(
Subscription.stripe_subscription_id == stripe_subscription_id
)
)
sub = result.scalar_one_or_none()
if sub is None:
return
if tier is not None:
sub.tier = tier
if status is not None:
sub.status = status
if current_period_end is not None:
sub.current_period_end = current_period_end
# Module-level singleton shared across the app.
stripe_service = StripeService()

View File

@@ -1,195 +0,0 @@
"""Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
Quota-enforcement helpers take ``tier`` directly — the caller already has it
from ``current_user.tier`` (provided by ``get_current_user``).
"""
from __future__ import annotations
from typing import Any
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import BillingTier
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
FEATURES: dict[str, dict[str, Any]] = {
"free": {
"agents": 3,
"batch_active": 2,
"batch_runs_per_day": 5,
"cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"batch_runs_per_day": 50,
"cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": False,
},
"team": {
"agents": -1,
"batch_active": -1,
"batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": True,
},
}
# Requests-per-minute limit per tier.
RATE_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
class TierManager:
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
# ── Tier lookup ─────────────────────────────────────────────────────
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
from app.config.settings import settings # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "power" if settings.ENV == "dev" else "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────
def check_feature(self, tier: BillingTier, feature: str) -> bool:
"""Return ``True`` if ``tier`` has ``feature`` enabled.
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
"""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if value is None:
return False
if isinstance(value, bool):
return value
return value != 0
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
if not self.check_feature(tier, feature):
detail = (
f"Feature '{feature}' requires {tier_name} tier or above."
if tier_name
else f"Feature '{feature}' is not available on your current tier."
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int:
"""Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app.
tier_manager = TierManager()

View File

View File

@@ -1,62 +0,0 @@
from typing import Literal
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
GITHUB_TOKEN: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
GITHUB_COPILOT_TOKEN_DIR: str = ""
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
settings = Settings()

View File

View File

@@ -1,30 +0,0 @@
"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
"""Common base for non-chat agents still using the old base contract."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
return []

File diff suppressed because it is too large Load Diff

View File

@@ -1,846 +0,0 @@
"""Single-agent runners for home and floating chat contexts."""
from __future__ import annotations
import json
import logging
import re
from datetime import date
from collections.abc import AsyncGenerator
from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
logger = logging.getLogger(__name__)
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
_HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)
_FLOATING_SINGLE_AGENT_SYSTEM = (
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Stay focused on the floating scope in context.scope and answer concisely. "
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
)
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [token for token in tokens if len(token) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
"""Resolve likely project UUID from user message using client project list."""
try:
result = await execute_on_client(action="select", table="projects")
except Exception as exc:
logger.warning("deep_agent: project resolve select failed: %s", exc)
return None
rows = result.get("rows", [])
if not isinstance(rows, list) or not rows:
return None
tokens = _candidate_tokens(message)
scored: list[tuple[int, dict[str, Any]]] = []
for row in rows:
if not isinstance(row, dict):
continue
name = str(row.get("name", "")).lower()
score = sum(1 for token in tokens if token in name)
if score > 0:
scored.append((score, row))
if not scored:
return None
scored.sort(key=lambda item: item[0], reverse=True)
top_score = scored[0][0]
top_rows = [row for score, row in scored if score == top_score]
if len(top_rows) != 1:
return None
project_id = top_rows[0].get("id")
return project_id if isinstance(project_id, str) else None
def _needs_project_resolution(message: str) -> bool:
lowered = message.lower()
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
prepared = dict(context)
if _needs_project_resolution(message):
resolved_project_id = await _resolve_project_id_from_message(message)
if resolved_project_id:
prepared["resolved_project_id"] = resolved_project_id
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
return prepared
def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
request_id = debug.get("request_id")
if isinstance(request_id, str) and request_id:
return request_id
return None
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
sanitized = dict(context)
sanitized.pop("_debug", None)
return sanitized
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
def _is_upcoming_timeline_query(message: str) -> bool:
lowered = message.lower()
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
has_timeline_topic = any(
token in lowered
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
)
return has_upcoming and has_timeline_topic
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
match = _TIMELINE_DMY_RE.search(dmy)
if not match:
return True
try:
parsed = date(
int(match.group("y")),
int(match.group("m")),
int(match.group("d")),
)
except ValueError:
return True
today = date.today()
return parsed >= today and parsed.year == today.year and parsed.month == today.month
def _normalize_tagged_list_lines(text: str, message: str) -> str:
if not text:
return text
upcoming_timeline_only = _is_upcoming_timeline_query(message)
output_lines: list[str] = []
for line in text.splitlines():
matches = list(_TAG_LINE_RE.finditer(line))
if not matches:
output_lines.append(line)
continue
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
if not had_non_tag_text and len(matches) == 1:
tag_text = matches[0].group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
continue
for match in matches:
tag_text = match.group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip()
if value.startswith("/memories/"):
value = value[len("/memories/"):]
value = value.strip("/")
return value
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
@tool
async def memory_list_blocks() -> str:
"""List all core memory blocks currently stored for the user."""
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
blocks = await memory.list_core_blocks(user_id)
if not blocks:
return "No memory blocks found."
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
return "Memory blocks:\n" + "\n".join(lines)
@tool
async def memory_get(path_or_label: str) -> str:
"""Get one memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
value = await memory.get_core_block(user_id, label)
if value is None:
return f"Memory block '{label}' not found."
return f"Memory block '{label}':\n{value}"
@tool
async def memory_create(path_or_label: str, value: str) -> str:
"""Create or overwrite a memory block value by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, label, value, trace_id=trace_id)
return f"Memory block '{label}' saved."
@tool
async def memory_append(path_or_label: str, content: str) -> str:
"""Append content to a memory block, creating it if missing."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.append_core(user_id, label, content)
return f"Memory block '{label}' appended."
@tool
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
"""Replace one exact string in a memory block."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
changed = await memory.replace_core(user_id, label, old_string, new_string)
if not changed:
return f"No replacement made in '{label}' (old string not found)."
return f"Memory block '{label}' updated."
@tool
async def memory_delete(path_or_label: str) -> str:
"""Delete a memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
deleted = await memory.delete_core(user_id, label)
if not deleted:
return f"Memory block '{label}' not found."
return f"Memory block '{label}' deleted."
@tool
async def archival_memory_insert(content: str) -> str:
"""Insert a long-term archival memory entry."""
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.insert_archival(user_id, content, source="assistant")
return "Archival memory saved."
@tool
async def archival_memory_search(query: str, top_k: int = 5) -> str:
"""Search long-term archival memory by semantic fallback (keyword currently)."""
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=top_k)
if not results:
return "No archival memory results found."
lines = [f"- {item}" for item in results]
return "Archival memory results:\n" + "\n".join(lines)
@tool
async def conversation_search(query: str, top_k: int = 5) -> str:
"""Search recall memory from prior episodic conversation summaries."""
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_recall(user_id, query, top_k=top_k)
if not results:
return "No recall memory results found."
lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines)
return [
memory_list_blocks,
memory_get,
memory_create,
memory_append,
memory_replace,
memory_delete,
archival_memory_insert,
archival_memory_search,
conversation_search,
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_llm()
response = await llm.ainvoke(
[
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
async def _run_single_agent(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
) -> str:
trace_id = _trace_id_from_context(context)
llm = get_llm()
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
final_text = _as_text(response.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
return final_text
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
return final_text
finally:
clear_tool_result_collector()
async def _run_single_agent_stream(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context)
llm = get_llm()
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
streamed_chars = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
emitted_any = False
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
emitted_any = True
yield "token", token
# Some providers return final text in `response.content` but stream no chunks.
if not emitted_any:
fallback_text = _as_text(response.content)
if fallback_text:
streamed_chars += len(fallback_text)
yield "token", fallback_text
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
return
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
yield "token", token
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
finally:
clear_tool_result_collector()
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
prepared_context = await _prepare_context(message, context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
message=message,
context=prepared_context,
)
return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
message=message,
context=prepared_context,
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
message=message,
context=prepared_context,
):
event_type, data = event
if event_type != "token":
yield event
continue
text_chunks.append(str(data or ""))
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
if normalized:
yield "token", normalized
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
message=message,
context=prepared_context,
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def update_core_memory(user_id: str, key: str, value: str) -> None:
"""Compatibility helper kept for callers that expect explicit memory update API."""
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)

View File

@@ -1,151 +0,0 @@
"""Device connection manager.
Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).
The manager handles the **tool-call round-trip** pattern:
- Backend sends ``tool_call`` frame → Electron executes the action →
returns ``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
This pattern is used by all tools (CRUD, file-system, etc.) via
``execute_on_client()`` in ``ws_context.py``.
The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass, field
from fastapi import WebSocket
logger = logging.getLogger(__name__)
@dataclass
class DeviceConnection:
"""State for a single connected Electron device."""
ws: WebSocket
device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class DeviceConnectionManager:
"""Singleton registry of active Electron WebSocket connections.
Thread/task safety note: asyncio is single-threaded by design. All
mutations happen inside await-points on the main event loop, so no
locking is required for the in-memory dicts.
"""
def __init__(self) -> None:
self._connections: dict[str, DeviceConnection] = {}
# ── Registration ──────────────────────────────────────────────────
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
"""Store the active connection for *user_id*, replacing any previous one."""
if user_id in self._connections:
old = self._connections[user_id]
logger.info(
"device_manager: replacing existing connection for user=%s device=%s",
user_id,
old.device_id,
)
# Cancel any futures that were waiting on the old connection.
for fut in old.pending_calls.values():
if not fut.done():
fut.cancel()
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
logger.info(
"device_manager: registered user=%s device=%s", user_id, device_id
)
def unregister(self, user_id: str) -> None:
"""Remove the connection for *user_id* and cancel any pending futures."""
conn = self._connections.pop(user_id, None)
if conn is None:
return
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
logger.info("device_manager: unregistered user=%s", user_id)
# ── Presence queries ──────────────────────────────────────────────
def get_ws(self, user_id: str) -> WebSocket | None:
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
conn = self._connections.get(user_id)
return conn.ws if conn else None
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
"""Return ``True`` if the user has an active connection.
If *device_id* is provided also checks that it matches the connected device.
"""
conn = self._connections.get(user_id)
if conn is None:
return False
if device_id is not None:
return conn.device_id == device_id
return True
# ── Frame sending ─────────────────────────────────────────────────
async def send_frame(self, user_id: str, frame: dict) -> None:
"""Send *frame* as a JSON text message to the device.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"send_frame: user {user_id!r} is not connected"
)
await conn.ws.send_text(json.dumps(frame))
# ── Tool-call round-trip ──────────────────────────────────────────
def create_pending_call(
self, user_id: str, call_id: str
) -> asyncio.Future[dict]:
"""Register a Future that will be resolved when the tool_result arrives.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"create_pending_call: user {user_id!r} is not connected"
)
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
return fut
def resolve_pending_call(
self, user_id: str, call_id: str, result: dict
) -> None:
"""Fulfil the Future registered under *call_id* with the Electron result.
No-ops if the call_id is unknown (already timed out or cancelled).
"""
conn = self._connections.get(user_id)
if conn is None:
return
fut = conn.pending_calls.pop(call_id, None)
if fut is not None and not fut.done():
fut.set_result(result)
# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()

View File

@@ -1,122 +0,0 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
* Anthropic: ``anthropic/claude-3.5-sonnet``
* Google: ``gemini/gemini-pro``
* Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** in ``.env``
— no code changes required.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from litellm import get_supported_openai_params # noqa: F401 validates install
from app.config.settings import settings
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
# Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github/"):
return settings.GITHUB_TOKEN or None
if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth.
return None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return a LangChain chat model backed by LiteLLM.
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
``openai`` client transparently when the model string contains a provider
prefix (``anthropic/…``, ``gemini/…``, etc.).
Parameters
----------
model:
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
temperature:
Sampling temperature. ``0`` = deterministic.
"""
model = model or settings.LLM_MODEL
# Point LiteLLM to the custom token directory when configured.
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if settings.GITHUB_TOKEN:
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
model names to preserve existing behaviour.
"""
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
# so the provider's auth mechanism is applied correctly.
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -1,441 +0,0 @@
"""Memory Middleware — enrich requests with memory context and store interactions.
Four-tier memory model (MemGPT-style):
core — persistent key/value user preferences, always injected
associative — semantic similarity search via pgvector (top-k)
episodic — recent session summaries (last N)
proactive — behavioral patterns above confidence threshold
All memory content is encrypted at rest using the per-user Fernet key
stored in User.encryption_key. Decryption happens in-memory only.
Usage:
memory = MemoryMiddleware(db_session)
context = await memory.enrich_context(user_id, message)
# ... run agent ...
await memory.store_episode(user_id, session_id, message, response)
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
logger = logging.getLogger(__name__)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None:
self._db = db
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
core_memory — {key: plaintext_value, ...}
associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold)
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return {}
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
len(core),
len(associative),
len(episodic),
len(proactive),
)
return {
"core_memory": core,
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
}
async def store_episode(
self,
user_id: str,
session_id: str,
message: str,
response: str,
trace_id: str | None = None,
) -> None:
"""Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. Full LLM summarisation can be added in a later step.
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(row)
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
session_id,
)
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, value)
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == key,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.value_encrypted = encrypted
else:
self._db.add(MemoryCore(
id=str(uuid.uuid4()),
user_id=user_id,
key=key,
value_encrypted=encrypted,
))
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: update_core trace=%s user=%s tier=%s key=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
key,
)
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
"""Return core memory as editable blocks (label/value)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore)
.where(MemoryCore.user_id == user_id)
.order_by(MemoryCore.key.asc())
)
rows = result.scalars().all()
out: list[dict[str, str]] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
"""Return a single core memory block value by label."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
return None
value = _safe_decrypt(fernet, row.value_encrypted)
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
return value
async def delete_core(self, user_id: str, label: str) -> bool:
"""Delete a core memory block by label. Returns True if deleted."""
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
return False
await self._db.delete(row)
try:
await self._db.commit()
logger.info("memory: delete_core user=%s label=%s", user_id, label)
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
"""Append content to a core block, creating it if missing."""
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
return
await self.update_core(user_id, label, f"{current}\n{content}")
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
"""Replace one exact string inside a core block. Returns False if not found."""
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=None,
entity_type=source,
entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search recall memory (episodic summaries) by keyword."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
# ── Private helpers ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
"""Load the user's Fernet key from DB. Returns None if missing."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory: no encryption_key for user=%s", user_id)
return None
return Fernet(user.encryption_key.encode())
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
)
rows = result.scalars().all()
out: dict[str, str] = {}
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out[row.key] = plaintext
return out
async def _load_associative(
self, user_id: str, message: str, fernet: Fernet
) -> list[str]:
"""Load top-k associative memories.
Production: uses pgvector cosine similarity on the message embedding.
Current implementation: keyword-based fallback (no external embedding call)
so tests pass without a live OpenAI key.
"""
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_ASSOCIATIVE_TOP_K)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_episodic(
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
query
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive)
.where(
MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
)
.order_by(MemoryProactive.confidence.desc())
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
# ── Encryption helpers ────────────────────────────────────────────────────────
def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
try:
return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc:
logger.warning("memory: decrypt failed: %s", exc)
return None

View File

@@ -1,47 +0,0 @@
"""Output formatter for deep-agent stream events."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class StreamFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
started = False
async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token":
continue
if not started:
yield WsStreamStart(request_id=self.request_id)
started = True
text = str(data or "")
if text:
yield WsStreamText(request_id=self.request_id, chunk=text)
if not started:
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id)

View File

@@ -1,92 +0,0 @@
"""WebSocket client executor context.
Holds a per-request async callback that tools call to execute CRUD
operations on the Electron client's local SQLite / LanceDB databases.
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
"""
from __future__ import annotations
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor"
)
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_tool_result_collector(lst: list[dict]) -> None:
"""Register *lst* as the collector for this async context."""
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
"""Clear the collector (best-effort)."""
_tool_result_collector.set(None)
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
_client_executor.set(fn)
def clear_client_executor() -> None:
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
try:
_client_executor.set(None) # type: ignore[arg-type]
except Exception:
pass
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a CRUD/vector operation to the Electron client and return the result.
Builds a ``tool_call`` payload, invokes the per-session WS callback,
and returns the ``tool_result`` dict from Electron.
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
"""
callback = _client_executor.get(None)
if callback is None:
raise RuntimeError(
"execute_on_client() called outside a WebSocket session — "
"no client executor is set."
)
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
result = await callback(payload)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -1,40 +0,0 @@
"""Database engine, session factory, and base model.
All app code uses the async SQLAlchemy API. Alembic migrations use the
synchronous psycopg2 URL for the CLI (see alembic/env.py).
Usage in routes:
from app.db import get_session
from sqlalchemy.ext.asyncio import AsyncSession
async def my_route(db: AsyncSession = Depends(get_session)):
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config.settings import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
echo=False,
)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
"""Shared declarative base for all ORM models."""
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session per request."""
async with async_session() as session:
yield session

View File

@@ -1,164 +0,0 @@
"""Cloud provider integration utilities.
Provides:
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
* ``get_provider()`` — factory that returns the correct client given a
provider name and decrypted OAuth credentials dict.
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
encryption for OAuth tokens stored in ``cloud_agent_configs``.
Encryption rationale
--------------------
Unlike user content (which is E2E-encrypted client-side and **never**
decrypted server-side), OAuth tokens *must* be decrypted server-side
because the backend makes provider API calls on behalf of the user.
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
is never returned to clients.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from app.config.settings import settings
if TYPE_CHECKING:
from app.integrations.gmail import GmailClient
from app.integrations.ms_graph import MSGraphClient
logger = logging.getLogger(__name__)
# ── Shared message types ──────────────────────────────────────────────────
@dataclass
class EmailMessage:
"""A single email message fetched from Gmail or Outlook."""
id: str
subject: str
sender: str
body_text: str
date: datetime
labels: list[str] = field(default_factory=list)
@property
def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M")
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{labels_str}\n"
f"Subject: {self.subject}\n\n"
f"{self.body_text}"
)
@dataclass
class ChatMessage:
"""A single Teams chat or channel message fetched from MS Graph."""
id: str
content: str
sender: str
channel: str | None
date: datetime
@property
def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M")
channel_str = f" [channel: {self.channel}]" if self.channel else ""
return (
f"From: {self.sender}\n"
f"Date: {date_str}{channel_str}\n\n"
f"{self.content}"
)
# ── Fernet helpers ────────────────────────────────────────────────────────
def _get_fernet() -> Fernet:
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
must ensure this is configured before persisting OAuth tokens.
"""
key = settings.OAUTH_ENCRYPTION_KEY
if not key:
raise RuntimeError(
"OAUTH_ENCRYPTION_KEY is not set. "
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
)
return Fernet(key.encode() if isinstance(key, str) else key)
def encrypt_token(token_info: dict) -> str:
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
Stores the full ``{access_token, refresh_token, token_uri, client_id,
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: ``token_info`` is not a non-empty dict.
"""
if not isinstance(token_info, dict) or not token_info:
raise ValueError("token_info must be a non-empty dict")
plaintext = json.dumps(token_info).encode("utf-8")
return _get_fernet().encrypt(plaintext).decode("utf-8")
def decrypt_token(encrypted: str) -> dict:
"""Decrypt a Fernet-encrypted token string and return the credential dict.
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: The encrypted string is invalid or was encrypted with a
different key.
"""
try:
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
return json.loads(plaintext)
except (InvalidToken, json.JSONDecodeError) as exc:
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
# ── Provider factory ──────────────────────────────────────────────────────
def get_provider(
provider: str,
credentials_info: dict,
) -> "GmailClient | MSGraphClient":
"""Return the correct provider client for *provider*.
Parameters
----------
provider:
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
credentials_info:
Decrypted OAuth credential dict (Google or Microsoft shape).
Raises:
ValueError: Unknown provider name.
"""
if provider == "gmail":
from app.integrations.gmail import GmailClient
return GmailClient(credentials_info)
if provider in {"outlook", "teams"}:
from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(credentials_info)
raise ValueError(
f"Unknown cloud provider {provider!r}. "
"Supported: 'gmail', 'outlook', 'teams'."
)

View File

@@ -1,335 +0,0 @@
"""Gmail API client for cloud agent integration.
Wraps the Google Gmail REST API to fetch email messages matching a
``filter_config`` dict. Uses the official ``google-api-python-client``
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
blocking the event loop.
Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2):
{
"token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "<client_id>",
"client_secret": "<client_secret>",
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
}
"""
from __future__ import annotations
import asyncio
import base64
import email
import html
import logging
import re
from datetime import datetime, timezone
from typing import Any
from app.integrations import EmailMessage
logger = logging.getLogger(__name__)
# Gmail search date format — e.g. "after:2025/01/01"
_GMAIL_DATE_FMT = "%Y/%m/%d"
# Maximum characters of body text forwarded to the LLM.
_BODY_TRUNCATE = 8_000
# Maximum messages retrieved per run (prevents runaway quota usage).
_MAX_MESSAGES = 200
def _build_gmail_query(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build a Gmail search query string from *filter_config* and *since*.
Supported ``filter_config`` keys:
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
senders (list[str]): Sender addresses or domains to include
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
A hard ``since`` date (from last run) always overrides ``date_range.from``
when it is earlier.
"""
parts: list[str] = []
cfg = filter_config or {}
# Labels — joined with OR when multiple given.
labels: list[str] = cfg.get("labels", [])
if labels:
if len(labels) == 1:
parts.append(f"label:{labels[0]}")
else:
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})")
# Senders — each prefixed with "from:".
senders: list[str] = cfg.get("senders", [])
for sender in senders:
parts.append(f"from:{sender}")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to")
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
if effective_since:
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
except ValueError:
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
return " ".join(parts)
def _strip_html(raw_html: str) -> str:
"""Remove HTML tags and decode entities to get plain text."""
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str:
"""Recursively extract the plain-text body from a Gmail message payload.
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
Returns an empty string if no body can be extracted.
"""
mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", [])
if mime_type == "text/plain":
data = body.get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type == "text/html":
data = body.get("data", "")
if data:
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return _strip_html(raw)
return ""
# Multipart — prefer text/plain part, fall back to text/html.
plain_fallback = ""
for part in parts:
part_mime = part.get("mimeType", "")
if part_mime == "text/plain":
return _parse_body(part)
if part_mime == "text/html" and not plain_fallback:
plain_fallback = _parse_body(part)
if part_mime.startswith("multipart/"):
nested = _parse_body(part)
if nested:
return nested
return plain_fallback
def _parse_date(raw: str) -> datetime:
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
try:
parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed.astimezone(timezone.utc)
except Exception:
return datetime.now(timezone.utc)
class GmailClient:
"""Fetch email messages from a Gmail account via the Gmail REST API.
Parameters
----------
credentials_info:
Decrypted OAuth2 credential dict. Must contain at minimum
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
``client_id`` + ``client_secret``.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None:
from google.oauth2.credentials import Credentials
self._credentials_info = credentials_info
expiry_str: str | None = credentials_info.get("expiry")
expiry: datetime | None = None
if expiry_str:
try:
expiry = datetime.fromisoformat(
expiry_str.replace("Z", "+00:00")
).replace(tzinfo=timezone.utc)
except ValueError:
pass
self._credentials = Credentials(
token=credentials_info.get("token"),
refresh_token=credentials_info.get("refresh_token"),
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
client_id=credentials_info.get("client_id"),
client_secret=credentials_info.get("client_secret"),
scopes=credentials_info.get("scopes"),
expiry=expiry,
)
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
to avoid blocking the async event loop.
Token refresh is performed automatically when the access token has
expired. After the call, ``self.refreshed_credentials`` may be
consulted to detect whether new credentials should be persisted.
"""
query = _build_gmail_query(filter_config, since)
logger.debug("gmail: executing search query %r", query)
return await asyncio.to_thread(self._fetch_sync, query)
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
If the credentials were refreshed during ``fetch_messages()``, returns
a new dict that should be re-encrypted and written back to the DB.
Returns ``None`` if no refresh occurred.
"""
creds = self._credentials
if not creds.valid and creds.expired:
return None
# Check whether the token changed from what was stored.
if creds.token != self._credentials_info.get("token"):
result = {
"token": creds.token,
"refresh_token": creds.refresh_token,
"token_uri": creds.token_uri,
"client_id": creds.client_id,
"client_secret": creds.client_secret,
"scopes": list(creds.scopes or []),
}
if creds.expiry:
result["expiry"] = creds.expiry.isoformat()
return result
return None
# ── Internal sync worker ───────────────────────────────────────────────
def _fetch_sync(self, query: str) -> list[EmailMessage]:
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
import googleapiclient.discovery
import googleapiclient.errors
from google.auth.transport.requests import Request
# Refresh token if needed before building the service.
if self._credentials.expired and self._credentials.refresh_token:
try:
self._credentials.refresh(Request())
except Exception as exc:
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
service = googleapiclient.discovery.build(
"gmail", "v1", credentials=self._credentials, cache_discovery=False
)
user_api = service.users() # type: ignore[attr-defined]
# ── List matching message IDs ──────────────────────────────────────
ids: list[str] = []
page_token: str | None = None
while len(ids) < _MAX_MESSAGES:
batch_size = min(100, _MAX_MESSAGES - len(ids))
kwargs: dict[str, Any] = {
"userId": "me",
"maxResults": batch_size,
}
if query:
kwargs["q"] = query
if page_token:
kwargs["pageToken"] = page_token
try:
resp = user_api.messages().list(**kwargs).execute()
except googleapiclient.errors.HttpError as exc:
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
for msg in resp.get("messages", []):
ids.append(msg["id"])
page_token = resp.get("nextPageToken")
if not page_token:
break
if not ids:
logger.debug("gmail: no messages matched query %r", query)
return []
logger.info("gmail: fetching %d message(s)", len(ids))
# ── Fetch individual message details ──────────────────────────────
messages: list[EmailMessage] = []
for msg_id in ids:
try:
msg = user_api.messages().get(
userId="me", id=msg_id, format="full"
).execute()
headers: dict[str, str] = {
h["name"].lower(): h["value"]
for h in msg.get("payload", {}).get("headers", [])
}
subject = headers.get("subject", "(no subject)")
sender = headers.get("from", "unknown")
date_raw = headers.get("date", "")
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
labels = msg.get("labelIds", [])
messages.append(EmailMessage(
id=msg_id,
subject=subject,
sender=sender,
body_text=body_text,
date=date,
labels=labels,
))
except googleapiclient.errors.HttpError as exc:
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
except Exception as exc:
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
logger.info("gmail: returned %d message(s)", len(messages))
return messages

View File

@@ -1,352 +0,0 @@
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
Handles two data sources:
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
``/me/chats/getAllMessages`` filtered by date.
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
dependency) is used for all API calls.
Credential dict shape (Microsoft OAuth2 / MSAL):
{
"access_token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_type": "Bearer",
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
"expires_in": 3600
}
"""
from __future__ import annotations
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import Any
import httpx
from app.config.settings import settings
from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
# Max items fetched per run.
_MAX_EMAILS = 200
_MAX_MESSAGES = 200
# Max characters of body forwarded to the LLM.
_BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str:
"""Strip HTML tags and collapse whitespace."""
no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html
decoded = _html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _odata_datetime(dt: datetime) -> str:
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
def _build_email_filter(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
Supported ``filter_config`` keys:
senders (list[str]): Sender email addresses.
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
folders (list[str]): Folder display names (not directly filterable
via OData, so ignored here — callers iterate
folder IDs separately if needed; listed for
completeness).
A hard ``since`` date always overrides ``date_range.from`` when it is
earlier.
"""
clauses: list[str] = []
cfg = filter_config or {}
# Senders.
senders: list[str] = cfg.get("senders", [])
if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
effective_since: datetime | None = since
if from_str:
try:
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
if cfg_since.tzinfo is None:
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
if effective_since is None or cfg_since > effective_since:
effective_since = cfg_since
except ValueError:
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
if effective_since:
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
to_str: str | None = date_range.get("to")
if to_str:
try:
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
if to_dt.tzinfo is None:
to_dt = to_dt.replace(tzinfo=timezone.utc)
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
except ValueError:
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
return " and ".join(clauses)
class MSGraphClient:
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
Parameters
----------
credentials_info:
Decrypted MSAL credential dict.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None:
self._credentials_info = credentials_info
self._access_token: str = credentials_info.get("access_token", "")
self._original_access_token: str = self._access_token
self._refresh_token: str | None = credentials_info.get("refresh_token")
# ── Token management ───────────────────────────────────────────────────
def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._access_token}"}
async def _refresh_access_token(self) -> None:
"""Use MSAL to exchange the refresh token for a fresh access token.
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
Raises:
RuntimeError: MSAL reports an auth error.
"""
import msal
app = msal.ConfidentialClientApplication(
client_id=settings.MS_CLIENT_ID,
client_credential=settings.MS_CLIENT_SECRET,
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
)
scopes: list[str] = self._credentials_info.get("scope", "").split()
if not scopes:
scopes = ["https://graph.microsoft.com/.default"]
result = app.acquire_token_by_refresh_token(
self._refresh_token,
scopes=scopes,
)
if "access_token" not in result:
error = result.get("error_description", result.get("error", "unknown"))
raise RuntimeError(f"MS Graph token refresh failed: {error}")
self._access_token = result["access_token"]
# MSAL may issue a new refresh token.
if "refresh_token" in result:
self._refresh_token = result["refresh_token"]
self._credentials_info["refresh_token"] = result["refresh_token"]
self._credentials_info["access_token"] = self._access_token
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
Returns ``None`` if no change was made.
"""
if self._access_token != self._original_access_token:
return {**self._credentials_info, "access_token": self._access_token}
return None
# ── HTTP helpers ───────────────────────────────────────────────────────
async def _get(
self,
client: httpx.AsyncClient,
url: str,
params: dict[str, Any] | None = None,
*,
retry_on_401: bool = True,
) -> dict[str, Any]:
"""GET *url* with auth; refresh token on 401 and retry once."""
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
logger.debug("ms_graph: 401 on %s — refreshing token", url)
await self._refresh_access_token()
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 429:
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
resp.raise_for_status()
return resp.json()
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_emails(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
Parameters
----------
filter_config:
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
since:
Hard lower-bound on email date (from last agent run).
"""
odata_filter = _build_email_filter(filter_config, since)
params: dict[str, Any] = {
"$top": 50,
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
"$orderby": "receivedDateTime desc",
}
if odata_filter:
params["$filter"] = odata_filter
emails: list[EmailMessage] = []
url = f"{_GRAPH_BASE}/me/messages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(emails) < _MAX_EMAILS:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
for item in data.get("value", []):
emails.append(self._parse_email(item))
if len(emails) >= _MAX_EMAILS:
break
url = data.get("@odata.nextLink", "")
params = {} # nextLink already contains encoded params.
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
return emails
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[ChatMessage]:
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
The ``filter_config.channels`` key is checked as a text-filter on
the channel name post-fetch (the API doesn't support channel OData
filter directly on ``getAllMessages``).
"""
cfg = filter_config or {}
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
params: dict[str, Any] = {"$top": 50}
if since:
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
messages: list[ChatMessage] = []
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
async with httpx.AsyncClient(timeout=30.0) as client:
while url and len(messages) < _MAX_MESSAGES:
try:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
except httpx.HTTPStatusError as exc:
# getAllMessages requires specific licensing; degrade gracefully.
if exc.response.status_code in (403, 404):
logger.warning(
"ms_graph: /me/chats/getAllMessages not available (%d) — "
"check Teams license or permissions",
exc.response.status_code,
)
break
raise
for item in data.get("value", []):
msg = self._parse_teams_message(item)
if channel_filter and msg.channel:
if not any(c in msg.channel.lower() for c in channel_filter):
continue
messages.append(msg)
if len(messages) >= _MAX_MESSAGES:
break
url = data.get("@odata.nextLink", "")
params = {}
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
return messages
# ── Parsers ────────────────────────────────────────────────────────────
@staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage:
subject: str = item.get("subject", "(no subject)") or "(no subject)"
sender_block = item.get("from", {}) or {}
sender_addr = (
(sender_block.get("emailAddress") or {}).get("address", "unknown")
)
date_str: str = item.get("receivedDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_body: str = body_block.get("content", "")
if content_type == "html":
body_text = _strip_html(raw_body)
else:
body_text = raw_body or item.get("bodyPreview", "")
body_text = body_text[:_BODY_TRUNCATE]
return EmailMessage(
id=item.get("id", ""),
subject=subject,
sender=sender_addr,
body_text=body_text,
date=date,
)
@staticmethod
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
msg_id: str = item.get("id", "")
sender_block = (item.get("from") or {}).get("user") or {}
sender: str = sender_block.get("displayName", "unknown")
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
date_str: str = item.get("createdDateTime", "")
try:
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
except Exception:
date = datetime.now(timezone.utc)
body_block = item.get("body", {}) or {}
content_type: str = body_block.get("contentType", "text")
raw_content: str = body_block.get("content", "")
content = _strip_html(raw_content) if content_type == "html" else raw_content
content = content[:_BODY_TRUNCATE]
return ChatMessage(
id=msg_id,
content=content,
sender=sender,
channel=channel,
date=date,
)

View File

@@ -1,72 +0,0 @@
from contextlib import asynccontextmanager
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401
yield
# Shutdown: dispose SQLAlchemy connection pool
from app.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Cloud API",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
# Request flow: TierRateLimit → Sanitizer → CORS → Router
# Response flow: Router → CORS → Sanitizer → TierRateLimit
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "version": app.version}
return app
app = create_app()

View File

@@ -1,7 +0,0 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -1,212 +0,0 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -1,125 +0,0 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:timelines",
"write:timelines",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -1,233 +0,0 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

View File

@@ -1,476 +0,0 @@
"""SQLAlchemy ORM models for all persistent tables.
Only auth, billing, storage metadata, and marketplace data live here.
User content (notes, tasks, etc.) is NEVER persisted server-side —
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Table inventory:
users — account credentials + tier
refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records
storage_records — S3 blob metadata (no plaintext)
backup_metadata — encrypted backup manifests
plugins — marketplace plugin catalog
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted)
memory_proactive — per-user behavioral patterns (encrypted)
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
JSON,
String,
Text,
UniqueConstraint,
Uuid,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db import Base
# ── Helpers ──────────────────────────────────────────────────────────────
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.now(timezone.utc)
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
# ── Models ────────────────────────────────────────────────────────────────
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
# Used to encrypt/decrypt all memory rows for this user.
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan"
)
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="refresh_tokens")
class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, unique=True, index=True
)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="local_agent",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,cloud_agent",
)
class CloudAgentConfig(Base):
__tablename__ = "cloud_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="cloud_agent",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,local_agent",
)
class AgentRunLog(Base):
__tablename__ = "agent_run_logs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
local_agent: Mapped[LocalAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,cloud_agent",
)
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,local_agent",
)
# ── Memory models ─────────────────────────────────────────────────────────────
class MemoryCore(Base):
"""Per-user persistent key/value preferences, encrypted at rest.
Examples: preferred_language, timezone, work_style.
Decrypted in-memory only using User.encryption_key.
"""
__tablename__ = "memory_core"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
key: Mapped[str] = mapped_column(String(255), nullable=False)
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryAssociative(Base):
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
Tests (SQLite): stored as JSON list.
"""
__tablename__ = "memory_associative"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryEpisodic(Base):
"""Per-user session summaries, encrypted at rest.
One row per session interaction; used to recall recent conversations.
"""
__tablename__ = "memory_episodic"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class MemoryProactive(Base):
"""Per-user inferred behavioral patterns, encrypted at rest.
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
"""
__tablename__ = "memory_proactive"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

View File

@@ -1,321 +0,0 @@
"""Pydantic schemas — API request/response contracts.
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
"""
from __future__ import annotations
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field
# ── Billing ──────────────────────────────────────────────────────────
BillingTier = Literal["free", "pro", "power", "team"]
# ── Auth ─────────────────────────────────────────────────────────────
class AuthTokens(BaseModel):
access_token: str
refresh_token: str
expires_at: int
class UserProfile(BaseModel):
id: str
email: str
name: str | None = None
surname: str | None = None
tier: BillingTier
# ── Chat ─────────────────────────────────────────────────────────────
class ChatContext(BaseModel):
user_profile: dict[str, Any] = Field(default_factory=dict)
relevant_documents: list[str] = Field(default_factory=list)
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class ChatRequest(BaseModel):
message: str
context: ChatContext = Field(default_factory=ChatContext)
class ChatResponse(BaseModel):
response: str
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str
# ── WebSocket Frame Protocol ──────────────────────────────────────────
class WsFrameType(str, Enum):
# ── v2 frame types (kept for backward compat) ──────────────────────
chat_request = "chat_request"
text_chunk = "text_chunk"
tool_call = "tool_call"
tool_result = "tool_result"
final = "final"
ping = "ping"
device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
# ── v4 journey frame types ────────────────────────────────────────
journey_start = "journey_start"
journey_message = "journey_message"
journey_reply = "journey_reply"
class WsToolCall(BaseModel):
"""Server → Client: requests a CRUD/vector operation on the local DB."""
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
id: str
action: str
table: str | None = None
data: dict[str, Any] | None = None
filters: dict[str, Any] | None = None
vector: list[float] | None = None
limit: int | None = None
class WsToolResult(BaseModel):
"""Client → Server: result of a CRUD/vector operation."""
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
id: str
row: dict[str, Any] | None = None
rows: list[dict[str, Any]] | None = None
results: list[dict[str, Any]] | None = None
deleted: bool | None = None
ok: bool | None = None
error: str | None = None
class WsTextChunk(BaseModel):
"""Server → Client: incremental LLM response text."""
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
text: str
class WsFinal(BaseModel):
"""Server → Client: signals end of response with the complete text."""
type: Literal[WsFrameType.final] = WsFrameType.final
response: str
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
class WsDeviceHello(BaseModel):
"""Client → Server: device identification on WS connect."""
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
device_id: str
agent_ids: list[str] = Field(default_factory=list)
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message."""
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
message: str
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
class WsStreamStart(BaseModel):
"""Server → Client: signals start of a streaming response."""
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
request_id: str
class WsStreamText(BaseModel):
"""Server → Client: streamed text token."""
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
request_id: str
chunk: str
class WsStreamEnd(BaseModel):
"""Server → Client: signals end of a streaming response."""
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Catalog ─────────────────────────────────────────────────────
class AgentCatalogItem(BaseModel):
type: str
name: str
description: str
class AgentCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0)
class AgentCreationCheckResponse(BaseModel):
allowed: bool
tier: BillingTier
active_agents: int
limit: int
class AgentTriggerRequest(BaseModel):
directory: str = Field(min_length=1)
device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
what_to_extract: list[str] = Field(min_length=1)
actions_by_type: dict[str, list[str]] | None = None
batch_interval: str = Field(min_length=1)
custom_agent_prompt: str = Field(min_length=1)
active_agents: int = Field(ge=0, default=0)
# ── Agent Run Log ─────────────────────────────────────────────────────
class AgentRunLogResponse(BaseModel):
id: str
agent_id: str
agent_type: Literal["local", "cloud"]
status: Literal["running", "success", "error", "partial"]
items_processed: int
items_created: int
errors: list[str]
started_at: int
completed_at: int | None
# ── Chatbot Journey ───────────────────────────────────────────────────

View File

@@ -1 +0,0 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

View File

@@ -1,106 +0,0 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

View File

@@ -1,32 +0,0 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

View File

@@ -1,205 +0,0 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -1,37 +0,0 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
stripe>=11.0.0
boto3>=1.35.0
slowapi>=0.1.9
sqlalchemy>=2.0.0
asyncpg>=0.30.0
alembic>=1.14.0
bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
psycopg2-binary>=2.9.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
aiosqlite>=0.20.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0
croniter>=3.0.0
google-api-python-client>=2.130.0
google-auth>=2.29.0
google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
redis>=5.0.0
langfuse>=3.0.0
ruff>=0.8.0