Refactor tests for execution plan and add comprehensive storage tests

- Updated `TestModuleSingletons` in `test_execution_plan.py` to reflect new agent templates and playbook names.
- Changed assertions in playbook tests to match updated templates and agents.
- Introduced `test_storage.py` to cover the storage layer, including encryption, BlobStore, and VectorStore functionalities.
- Added tests for S3 interactions, ensuring upload, download, delete, and list operations work as expected.
- Implemented mock tests for Pinecone and Qdrant vector stores to validate upsert, search, and delete operations.
This commit is contained in:
2026-03-02 15:36:09 +01:00
parent 35dd9ac86f
commit c8ef7b119b
21 changed files with 1980 additions and 469 deletions

View File

@@ -1,5 +1,5 @@
"""Import all agent modules to trigger @registry.register decorators.""" """Import all agent modules to trigger @registry.register decorators."""
from app.agents import analytics_agent, calendar_agent, email_agent, task_agent from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
__all__ = ["analytics_agent", "calendar_agent", "email_agent", "task_agent"] __all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -1,80 +0,0 @@
"""Analytics agent — metrics, reports, and trend analysis."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = (
"You are a workspace analytics assistant. Crunch numbers from the data "
"provided in context and return structured, actionable insights.\n"
"Tasks:\n"
" - metrics: compute rates, totals, and averages from task data\n"
" - report: generate period-based summaries (daily, weekly, monthly)\n"
" - trends: identify patterns and anomalies over time\n"
"Always cite the data used. Do not fabricate figures."
)
@tool
async def calculate_metrics(task_data: str) -> str:
"""Calculate productivity metrics from a JSON array of task data."""
return json.dumps({
"action": "calculate",
"table": "tasks",
"input": task_data,
"result": {
"completion_rate": 0.0,
"overdue_count": 0,
"avg_priority": "medium",
},
})
@tool
async def generate_report(period: str, data: str) -> str:
"""Generate a structured report for a time period (e.g. 'last_7_days', 'last_month')."""
return json.dumps({
"action": "report",
"period": period,
"input": data,
})
@tool
async def trend_analysis(data_points: str) -> str:
"""Analyse trends in a JSON array of time-series data points."""
return json.dumps({
"action": "trend",
"input": data_points,
"result": {"trend": "stable", "anomalies": []},
})
@registry.register
class AnalyticsAgent(ChatAgent):
def get_name(self) -> str:
return "analytics_agent"
def get_description(self) -> str:
return "Workspace analytics: metrics, reports, trends"
def get_tools(self) -> list[Any]:
return [calculate_metrics, generate_report, trend_analysis]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,76 +0,0 @@
"""Calendar agent — events, conflict detection, and scheduling."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = (
"You are a calendar management assistant. Help the user manage events, "
"detect scheduling conflicts, and suggest reschedules.\n"
"Rules:\n"
" - Work exclusively with event metadata provided in context\n"
" - Never store or reference raw calendar data\n"
" - date_range format: ISO 8601 interval, e.g. '2024-01-01/2024-01-07'\n"
" - Always confirm the date/time scope of any operation"
)
@tool
async def list_events(date_range: str) -> str:
"""List calendar events in a date range (ISO 8601 interval, e.g. '2024-01-01/2024-01-07')."""
return json.dumps({
"action": "list",
"table": "events",
"filters": {"date_range": date_range},
})
@tool
async def detect_conflicts(events: str) -> str:
"""Detect scheduling conflicts in a JSON array of event metadata objects."""
return json.dumps({
"action": "analyse",
"table": "events",
"input": events,
"result": "conflicts_detected",
})
@tool
async def suggest_reschedule(conflict: str) -> str:
"""Suggest a reschedule for a conflicting event. Pass the conflict as a JSON string."""
return json.dumps({
"action": "suggest_reschedule",
"table": "events",
"input": conflict,
})
@registry.register
class CalendarAgent(ChatAgent):
def get_name(self) -> str:
return "calendar_agent"
def get_description(self) -> str:
return "Calendar management: events, conflicts, scheduling"
def get_tools(self) -> list[Any]:
return [list_events, detect_conflicts, suggest_reschedule]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -0,0 +1,122 @@
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = (
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all checkpoints across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_checkpoints(project_id: str = "") -> str:
"""List checkpoints. Provide project_id to scope to a specific project."""
return json.dumps({
"action": "list",
"table": "checkpoints",
"filters": {"projectId": project_id or None},
})
@tool
async def create_checkpoint(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str:
"""Create a project checkpoint (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms
"""
return json.dumps({
"action": "create_record",
"table": "checkpoints",
"data": {
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
},
})
@tool
async def update_checkpoint(
checkpoint_id: str,
title: str = "",
date: int = -1,
is_approved: int = -1,
) -> str:
"""Update a checkpoint. Only pass fields that should change.
checkpoint_id: UUID of the checkpoint (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
is_approved: -1 means unchanged; 0 or 1 sets the approval state
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
if is_approved != -1:
updates["isApproved"] = is_approved
return json.dumps({
"action": "update_record",
"table": "checkpoints",
"data": {"id": checkpoint_id, "updates": updates},
})
@tool
async def delete_checkpoint(checkpoint_id: str) -> str:
"""Delete a checkpoint permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "checkpoints",
"data": {"id": checkpoint_id},
})
@registry.register
class CheckpointAgent(ChatAgent):
def get_name(self) -> str:
return "checkpoint_agent"
def get_description(self) -> str:
return "Manages project checkpoints (milestones): list, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,77 +0,0 @@
"""Email agent — classify, extract action items, draft responses."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = (
"You are an email analysis assistant. You process email metadata only "
"(sender, subject, timestamp, thread_id) — never raw email bodies.\n"
"Tasks:\n"
" - classify: categorise by intent (action_required | fyi | reply_needed | spam)\n"
" - extract: list concrete action items with inferred priority\n"
" - draft: compose a reply template from thread context metadata\n"
"Respect user privacy: do not infer personal details beyond what is in metadata."
)
@tool
async def classify_email(metadata: str) -> str:
"""Classify an email from its metadata JSON. Returns category and confidence score."""
return json.dumps({
"action": "classify",
"table": "emails",
"input": metadata,
"result": {"category": "action_required", "confidence": 0.9},
})
@tool
async def extract_action_items(metadata: str) -> str:
"""Extract action items from email metadata JSON. Returns a list of task descriptions."""
return json.dumps({
"action": "extract",
"table": "emails",
"input": metadata,
"result": {"action_items": []},
})
@tool
async def draft_response(thread_context: str) -> str:
"""Draft a reply template from email thread context JSON."""
return json.dumps({
"action": "draft",
"table": "emails",
"input": thread_context,
})
@registry.register
class EmailAgent(ChatAgent):
def get_name(self) -> str:
return "email_agent"
def get_description(self) -> str:
return "Email analysis: classify, extract actions, draft responses"
def get_tools(self) -> list[Any]:
return [classify_email, extract_action_items, draft_response]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

123
app/agents/note_agent.py Normal file
View File

@@ -0,0 +1,123 @@
"""Note agent — Markdown note management (list, get, create, update, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
return json.dumps({
"action": "list",
"table": "notes",
"filters": {"projectId": project_id or None},
})
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
return json.dumps({
"action": "get",
"table": "notes",
"data": {"id": note_id},
})
@tool
async def create_note(
title: str,
content: str,
project_id: str = "",
) -> str:
"""Create a new note.
title: note heading (required)
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
return json.dumps({
"action": "create_record",
"table": "notes",
"data": {
"title": title,
"content": content,
"projectId": project_id or None,
},
})
@tool
async def update_note(
note_id: str,
title: str = "",
content: str = "",
) -> str:
"""Update an existing note. Only pass fields that should change.
note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first.
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
return json.dumps({
"action": "update_record",
"table": "notes",
"data": {"id": note_id, "updates": updates},
})
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "notes",
"data": {"id": note_id},
})
@registry.register
class NoteAgent(ChatAgent):
def get_name(self) -> str:
return "note_agent"
def get_description(self) -> str:
return "Manages notes: list, get, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_notes, get_note, create_note, update_note, delete_note]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

158
app/agents/project_agent.py Normal file
View File

@@ -0,0 +1,158 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
client_id: str = "",
include_archived: int = 0,
) -> str:
"""List projects, optionally filtered by client_id.
include_archived: 1 to include archived projects, 0 for active only (default).
"""
return json.dumps({
"action": "list",
"table": "projects",
"filters": {
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
})
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status.
Use only when the user wants a complete cross-client overview.
"""
return json.dumps({
"action": "list_all",
"table": "projects",
})
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
return json.dumps({
"action": "get",
"table": "projects",
"data": {"id": project_id},
})
@tool
async def create_project(
name: str,
client_id: str = "",
) -> str:
"""Create a new project.
name: human-readable project name (required)
client_id: optional UUID of the owning client
"""
return json.dumps({
"action": "create_record",
"table": "projects",
"data": {
"name": name,
"clientId": client_id or None,
},
})
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change.
project_id: UUID of the project (required)
status: active | archived
ai_summary: AI-generated summary text (populate only when explicitly requested)
"""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
return json.dumps({
"action": "update_record",
"table": "projects",
"data": {"id": project_id, "updates": updates},
})
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project and orphan its tasks.
IMPORTANT: prefer update_project(status='archived') unless the user
has explicitly confirmed they want permanent deletion.
"""
return json.dumps({
"action": "delete_record",
"table": "projects",
"data": {"id": project_id},
})
@registry.register
class ProjectAgent(ChatAgent):
def get_name(self) -> str:
return "project_agent"
def get_description(self) -> str:
return "Manages projects: list, get, create, update, archive, delete"
def get_tools(self) -> list[Any]:
return [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

View File

@@ -1,4 +1,4 @@
"""Task agent — create, update, list, and suggest tasks.""" """Task agent — full CRUD for tasks and task comments."""
from __future__ import annotations from __future__ import annotations
@@ -13,40 +13,121 @@ from app.config.settings import settings
from app.core.agent_registry import ChatAgent, registry from app.core.agent_registry import ChatAgent, registry
_SYSTEM_PROMPT = ( _SYSTEM_PROMPT = (
"You are a task management assistant (PM-oriented). Help the user create, " "You are a task management assistant for a project workspace.\n"
"update, list, and suggest tasks.\n" "You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n" "Rules:\n"
" - priority must be one of: low, medium, high, urgent\n" " - status must be one of: todo, in_progress, done\n"
" - infer priority from context clues (deadlines, urgency language, dependencies)\n" " - priority must be one of: high, medium, low\n"
" - due_date as ISO 8601 string when provided\n" " - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - context fields beyond user_profile are optional; use them when present\n" " - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
"Use the available tools to act, then confirm what was done in plain language." " - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
) )
# ── Task tools ────────────────────────────────────────────────────────
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
return json.dumps({
"action": "list",
"table": "tasks",
"filters": {
"projectId": project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
})
@tool @tool
async def create_task( async def create_task(
title: str, title: str,
description: str = "", description: str = "",
status: str = "todo",
priority: str = "medium", priority: str = "medium",
due_date: str = "", assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str: ) -> str:
"""Create a new task. priority: low | medium | high | urgent. due_date: ISO 8601.""" """Create a new task.
title: task title (required)
description: optional details
status: todo | in_progress | done (default: todo)
priority: high | medium | low (default: medium)
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms; 1 when confirmed
"""
return json.dumps({ return json.dumps({
"action": "create_record", "action": "create_record",
"table": "tasks", "table": "tasks",
"data": { "data": {
"title": title, "title": title,
"description": description, "description": description or None,
"status": status,
"priority": priority, "priority": priority,
"due_date": due_date, "assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
}, },
}) })
@tool @tool
async def update_task(task_id: str, updates: str) -> str: async def update_task(
"""Update fields on an existing task. Pass updates as a JSON string, e.g. '{"priority":"high"}'.""" task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
is_approved: int = -1,
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
is_approved: -1 means unchanged; 0 or 1 sets the value
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
if is_approved != -1:
updates["isApproved"] = is_approved
return json.dumps({ return json.dumps({
"action": "update_record", "action": "update_record",
"table": "tasks", "table": "tasks",
@@ -55,35 +136,87 @@ async def update_task(task_id: str, updates: str) -> str:
@tool @tool
async def list_tasks(status: str = "", priority: str = "") -> str: async def delete_task(task_id: str) -> str:
"""List tasks. Optionally filter by status (open|done|archived) or priority level.""" """Delete a task permanently by its UUID."""
return json.dumps({ return json.dumps({
"action": "list", "action": "delete_record",
"table": "tasks", "table": "tasks",
"filters": {"status": status, "priority": priority}, "data": {"id": task_id},
}) })
@tool @tool
async def suggest_tasks(context: str) -> str: async def list_tasks_due_today() -> str:
"""Suggest new tasks based on notes or free-form context text.""" """List all tasks whose due date falls on today's date."""
return json.dumps({ return json.dumps({
"action": "suggest", "action": "list_due_today",
"table": "tasks", "table": "tasks",
"context": context,
}) })
# ── Task comment tools ────────────────────────────────────────────────
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
return json.dumps({
"action": "list",
"table": "taskComments",
"filters": {"taskId": task_id},
})
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task.
task_id: UUID of the task to comment on
author: name or ID of the comment author
content: comment text
"""
return json.dumps({
"action": "create_record",
"table": "taskComments",
"data": {
"taskId": task_id,
"author": author,
"content": content,
},
})
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "taskComments",
"data": {"id": comment_id},
})
# ── Agent ─────────────────────────────────────────────────────────────
@registry.register @registry.register
class TaskAgent(ChatAgent): class TaskAgent(ChatAgent):
def get_name(self) -> str: def get_name(self) -> str:
return "task_agent" return "task_agent"
def get_description(self) -> str: def get_description(self) -> str:
return "Manages tasks: create, update, list, suggest" return "Manages tasks and comments: list, create, update, delete, due-today, comments"
def get_tools(self) -> list[Any]: def get_tools(self) -> list[Any]:
return [create_task, update_task, list_tasks, suggest_tasks] return [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
async def handle(self, query: str, context: dict[str, Any]) -> str: async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)

46
app/api/deps.py Normal file
View File

@@ -0,0 +1,46 @@
"""Shared FastAPI dependencies.
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
Step 9 will layer rate-limiting and sanitization middleware on top of this.
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from app.config.settings import settings
from app.schemas import BillingTier, UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
Raises ``HTTP 401`` on any invalid or expired token.
The tier embedded in the JWT is used for feature-gating until Step 12
adds a live DB lookup.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
tier: str = payload.get("tier", "free")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]

118
app/api/routes/auth.py Normal file
View File

@@ -0,0 +1,118 @@
"""Auth routes: register, login, refresh, me.
Users and refresh tokens are kept in an in-memory dict until Step 12
migrates them to PostgreSQL.
"""
from __future__ import annotations
import time
import uuid
from typing import Any
import bcrypt
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.config.settings import settings
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
_users: dict[str, dict[str, Any]] = {} # email → user record
_refresh_tokens: dict[str, str] = {} # plain token → user_id
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
now = int(time.time())
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
access_payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": access_exp,
"iat": now,
}
access_token = jwt.encode(
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
refresh_token = str(uuid.uuid4())
_refresh_tokens[refresh_token] = user_id
return AuthTokens(
access_token=access_token,
refresh_token=refresh_token,
expires_at=access_exp * 1000, # milliseconds for client
)
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(body: _RegisterRequest) -> AuthTokens:
"""Create a new account and return JWT tokens."""
if body.email in _users:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user_id = str(uuid.uuid4())
_users[body.email] = {
"id": user_id,
"email": body.email,
"password_hash": _hash_password(body.password),
"tier": "free",
}
return _make_tokens(user_id, body.email, "free")
@router.post("/login", response_model=AuthTokens)
async def login(body: _LoginRequest) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
user = _users.get(body.email)
if not user or not _verify_password(body.password, user["password_hash"]):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
return _make_tokens(user["id"], user["email"], user["tier"])
@router.post("/refresh", response_model=AuthTokens)
async def refresh(body: _RefreshRequest) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
user_id = _refresh_tokens.pop(body.refresh_token, None)
if user_id is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
user = next((u for u in _users.values() if u["id"] == user_id), None)
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
return _make_tokens(user["id"], user["email"], user["tier"])
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user

View File

@@ -17,6 +17,11 @@ class Settings(BaseSettings):
AWS_ACCESS_KEY_ID: str = "" AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = "" AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = "" OPENAI_API_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]

View File

@@ -156,29 +156,33 @@ def _register_builtin_templates() -> None:
_tpls: dict[str, str] = { _tpls: dict[str, str] = {
"tpl_task_agent_default": ( "tpl_task_agent_default": (
"You are a task management assistant. Help the user create, update, " "You are a task management assistant. Help the user create, update, "
"and prioritize tasks based on their message and context." "list, and track tasks. Use correct status values (todo, in_progress, "
"done) and priority values (high, medium, low) from the workspace model."
), ),
"tpl_calendar_agent_default": ( "tpl_checkpoint_agent_default": (
"You are a calendar assistant. Help manage events, detect scheduling " "You are a project checkpoint assistant. Help the user create and manage "
"conflicts, and suggest improvements based on the provided context." "milestone checkpoints on their projects. Every checkpoint requires a "
"project_id and a date expressed as a Unix timestamp in milliseconds."
), ),
"tpl_email_agent_default": ( "tpl_project_agent_default": (
"You are an email analysis assistant. Classify emails, extract action " "You are a project management assistant. Help the user create, find, "
"items, and draft responses using only the metadata provided." "update, and archive projects. Projects have a name, an optional client, "
"and a status of either active or archived."
), ),
"tpl_analytics_agent_default": ( "tpl_note_agent_default": (
"You are a workspace analytics assistant. Calculate metrics, generate " "You are a note-taking assistant. Help the user create, retrieve, update, "
"reports, and surface trends from the data provided in context." "and delete Markdown notes. Notes can optionally be linked to a project."
), ),
"tpl_email_extract_action_items": ( "tpl_task_extract_from_project": (
"Extract all action items from the provided email metadata. " "Extract all actionable tasks from the provided project context. "
"Return a structured list of tasks, each with a title, inferred " "Return a structured list of tasks, each with a title, inferred priority "
"priority, and suggested due date where possible." "(high, medium, or low), suggested status (todo), and a due_date in "
"milliseconds where a deadline can be inferred."
), ),
"tpl_analytics_weekly_summary": ( "tpl_note_weekly_summary": (
"Generate a weekly performance summary from the provided analytics " "Generate a weekly project summary note from the provided workspace data. "
"data. Include task completion rate, overdue item count, top " "Include: tasks completed this week, tasks due soon, active projects, "
"priorities for the coming week, and notable trends." "and upcoming checkpoints. Format the output as clean Markdown."
), ),
} }
for tid, text in _tpls.items(): for tid, text in _tpls.items():
@@ -189,20 +193,20 @@ def _load_playbooks() -> None:
"""Pre-build and cache the built-in playbooks.""" """Pre-build and cache the built-in playbooks."""
playbooks: list[tuple[str, ExecutionPlan]] = [ playbooks: list[tuple[str, ExecutionPlan]] = [
( (
"create_task_from_email", "create_tasks_from_project",
ExecutionPlanBuilder("email_agent") ExecutionPlanBuilder("project_agent")
.add_llm_step( .add_llm_step(
"tpl_email_extract_action_items", "tpl_task_extract_from_project",
{"source": "email_metadata"}, {"source": "project_context"},
) )
.add_data_step("create_record", data_from_step=0) .add_data_step("create_record", data_from_step=0)
.build(), .build(),
), ),
( (
"generate_weekly_report", "generate_weekly_note",
ExecutionPlanBuilder("analytics_agent") ExecutionPlanBuilder("note_agent")
.add_llm_step( .add_llm_step(
"tpl_analytics_weekly_summary", "tpl_note_weekly_summary",
{"period": "last_7_days"}, {"period": "last_7_days"},
) )
.add_data_step("create_record", data_from_step=0) .add_data_step("create_record", data_from_step=0)

View File

@@ -82,3 +82,76 @@ class BackupMetadata(BaseModel):
timestamp: int timestamp: int
checksum: str checksum: str
chunk_count: int chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str

1
app/storage/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

105
app/storage/blob_store.py Normal file
View File

@@ -0,0 +1,105 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from botocore.exceptions import ClientError
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
return boto3.client(
"s3",
region_name=settings.S3_REGION,
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

32
app/storage/encryption.py Normal file
View File

@@ -0,0 +1,32 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

205
app/storage/vector_store.py Normal file
View File

@@ -0,0 +1,205 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -17,3 +17,6 @@ httpx>=0.28.0
websockets>=14.0 websockets>=14.0
pytest>=8.0.0 pytest>=8.0.0
pytest-asyncio>=0.24.0 pytest-asyncio>=0.24.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0

View File

@@ -1,4 +1,4 @@
"""Unit tests for all four chat agents with mocked LLM.""" """Unit tests for the four domain-specific chat agents with mocked LLM."""
from __future__ import annotations from __future__ import annotations
@@ -9,9 +9,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import app.agents # noqa: F401 — triggers @registry.register decorators import app.agents # noqa: F401 — triggers @registry.register decorators
from app.agents.analytics_agent import AnalyticsAgent from app.agents.checkpoint_agent import CheckpointAgent
from app.agents.calendar_agent import CalendarAgent from app.agents.note_agent import NoteAgent
from app.agents.email_agent import EmailAgent from app.agents.project_agent import ProjectAgent
from app.agents.task_agent import TaskAgent from app.agents.task_agent import TaskAgent
from app.core.agent_registry import registry from app.core.agent_registry import registry
@@ -59,15 +59,15 @@ def _mock_llm_with_tool_call(
class TestAgentRegistration: class TestAgentRegistration:
def test_all_agents_registered(self) -> None: def test_all_agents_registered(self) -> None:
names = {a["name"] for a in registry.list_agents()} names = {a["name"] for a in registry.list_agents()}
assert {"task_agent", "calendar_agent", "email_agent", "analytics_agent"}.issubset( assert {
names "task_agent", "checkpoint_agent", "project_agent", "note_agent"
) }.issubset(names)
def test_registry_returns_correct_types(self) -> None: def test_registry_returns_correct_types(self) -> None:
assert isinstance(registry.get("task_agent"), TaskAgent) assert isinstance(registry.get("task_agent"), TaskAgent)
assert isinstance(registry.get("calendar_agent"), CalendarAgent) assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
assert isinstance(registry.get("email_agent"), EmailAgent) assert isinstance(registry.get("project_agent"), ProjectAgent)
assert isinstance(registry.get("analytics_agent"), AnalyticsAgent) assert isinstance(registry.get("note_agent"), NoteAgent)
def test_descriptions_present(self) -> None: def test_descriptions_present(self) -> None:
for agent_info in registry.list_agents(): for agent_info in registry.list_agents():
@@ -82,14 +82,23 @@ class TestTaskAgent:
assert TaskAgent().get_name() == "task_agent" assert TaskAgent().get_name() == "task_agent"
def test_description(self) -> None: def test_description(self) -> None:
assert TaskAgent().get_description() == "Manages tasks: create, update, list, suggest" assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
def test_get_tools_count(self) -> None: def test_get_tools_count(self) -> None:
assert len(TaskAgent().get_tools()) == 4 assert len(TaskAgent().get_tools()) == 8
def test_tool_names(self) -> None: def test_tool_names(self) -> None:
names = {t.name for t in TaskAgent().get_tools()} names = {t.name for t in TaskAgent().get_tools()}
assert names == {"create_task", "update_task", "list_tasks", "suggest_tasks"} assert names == {
"list_tasks",
"create_task",
"update_task",
"delete_task",
"list_tasks_due_today",
"list_task_comments",
"add_task_comment",
"delete_task_comment",
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_returns_string(self) -> None: async def test_handle_returns_string(self) -> None:
@@ -111,10 +120,10 @@ class TestTaskAgent:
mock_cls.return_value = _mock_llm_with_tool_call( mock_cls.return_value = _mock_llm_with_tool_call(
"create_task", "create_task",
{"title": "Buy groceries", "priority": "low"}, {"title": "Buy groceries", "priority": "low"},
"Task 'Buy groceries' created with low priority.", "Task 'Buy groceries' created.",
) )
result = await TaskAgent().handle("add a grocery task", {}) result = await TaskAgent().handle("add a grocery task", {})
assert result == "Task 'Buy groceries' created with low priority." assert result == "Task 'Buy groceries' created."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None: async def test_handle_accepts_empty_context(self) -> None:
@@ -123,20 +132,11 @@ class TestTaskAgent:
result = await TaskAgent().handle("help", {}) result = await TaskAgent().handle("help", {})
assert isinstance(result, str) assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_accepts_partial_context(self) -> None:
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TaskAgent().handle("list tasks", {"user_profile": {"id": "u1"}})
assert isinstance(result, str)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_accepts_rich_context(self) -> None: async def test_handle_accepts_rich_context(self) -> None:
context = { context = {
"user_profile": {"id": "u1", "tier": "pro"}, "user_profile": {"id": "u1", "tier": "pro"},
"recent_tasks": [{"id": "t1", "title": "Old task"}], "recent_tasks": [{"id": "t1", "title": "Old task"}],
"relevant_documents": ["doc1"],
"extra_plugin_data": {"batch_id": "b1"},
} }
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls: with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Tasks listed.") mock_cls.return_value = _mock_llm("Tasks listed.")
@@ -146,244 +146,475 @@ class TestTaskAgent:
class TestTaskAgentTools: class TestTaskAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_task_returns_valid_json(self) -> None: async def test_list_tasks_defaults(self) -> None:
from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "tasks"
@pytest.mark.asyncio
async def test_list_tasks_with_status_filter(self) -> None:
from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({"status": "done"})
data = json.loads(result)
assert data["filters"]["status"] == "done"
@pytest.mark.asyncio
async def test_create_task_defaults(self) -> None:
from app.agents.task_agent import create_task from app.agents.task_agent import create_task
result = await create_task.ainvoke({"title": "Test task", "priority": "high"}) result = await create_task.ainvoke({"title": "Test task"})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "create_record" assert data["action"] == "create_record"
assert data["table"] == "tasks" assert data["table"] == "tasks"
assert data["data"]["title"] == "Test task" assert data["data"]["title"] == "Test task"
assert data["data"]["priority"] == "high" assert data["data"]["status"] == "todo"
assert data["data"]["priority"] == "medium"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_task_returns_valid_json(self) -> None: async def test_create_task_with_all_fields(self) -> None:
from app.agents.task_agent import create_task
result = await create_task.ainvoke({
"title": "Deploy",
"priority": "high",
"status": "in_progress",
"project_id": "p1",
"is_ai_suggested": 1,
})
data = json.loads(result)
assert data["data"]["priority"] == "high"
assert data["data"]["status"] == "in_progress"
assert data["data"]["projectId"] == "p1"
assert data["data"]["isAiSuggested"] == 1
@pytest.mark.asyncio
async def test_update_task_with_status(self) -> None:
from app.agents.task_agent import update_task from app.agents.task_agent import update_task
result = await update_task.ainvoke( result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
{"task_id": "t1", "updates": '{"priority": "urgent"}'}
)
data = json.loads(result) data = json.loads(result)
assert data["action"] == "update_record" assert data["action"] == "update_record"
assert data["data"]["id"] == "t1" assert data["data"]["id"] == "t1"
assert data["data"]["updates"]["status"] == "done"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_tasks_returns_valid_json(self) -> None: async def test_update_task_empty_updates(self) -> None:
from app.agents.task_agent import list_tasks from app.agents.task_agent import update_task
result = await list_tasks.ainvoke({"status": "open"}) result = await update_task.ainvoke({"task_id": "t1"})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "list" assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_task(self) -> None:
from app.agents.task_agent import delete_task
result = await delete_task.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "tasks"
assert data["data"]["id"] == "t1"
@pytest.mark.asyncio
async def test_list_tasks_due_today(self) -> None:
from app.agents.task_agent import list_tasks_due_today
result = await list_tasks_due_today.ainvoke({})
data = json.loads(result)
assert data["action"] == "list_due_today"
assert data["table"] == "tasks" assert data["table"] == "tasks"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_suggest_tasks_returns_valid_json(self) -> None: async def test_list_task_comments(self) -> None:
from app.agents.task_agent import suggest_tasks from app.agents.task_agent import list_task_comments
result = await suggest_tasks.ainvoke({"context": "lots of meetings this week"}) result = await list_task_comments.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["action"] == "suggest"
# ── CalendarAgent ─────────────────────────────────────────────────────
class TestCalendarAgent:
def test_name(self) -> None:
assert CalendarAgent().get_name() == "calendar_agent"
def test_description(self) -> None:
assert CalendarAgent().get_description() == "Calendar management: events, conflicts, scheduling"
def test_get_tools_count(self) -> None:
assert len(CalendarAgent().get_tools()) == 3
def test_tool_names(self) -> None:
names = {t.name for t in CalendarAgent().get_tools()}
assert names == {"list_events", "detect_conflicts", "suggest_reschedule"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("No conflicts found.")
result = await CalendarAgent().handle("check my schedule", {})
assert result == "No conflicts found."
@pytest.mark.asyncio
async def test_handle_with_list_events_tool_call(self) -> None:
with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"list_events",
{"date_range": "2024-01-01/2024-01-07"},
"You have 3 events next week.",
)
result = await CalendarAgent().handle("what events do I have?", {})
assert result == "You have 3 events next week."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await CalendarAgent().handle("reschedule meeting", {})
assert isinstance(result, str)
class TestCalendarAgentTools:
@pytest.mark.asyncio
async def test_list_events_returns_valid_json(self) -> None:
from app.agents.calendar_agent import list_events
result = await list_events.ainvoke({"date_range": "2024-01-01/2024-01-07"})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "list" assert data["action"] == "list"
assert data["table"] == "events" assert data["table"] == "taskComments"
assert data["filters"]["date_range"] == "2024-01-01/2024-01-07" assert data["filters"]["taskId"] == "t1"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_detect_conflicts_returns_valid_json(self) -> None: async def test_add_task_comment(self) -> None:
from app.agents.calendar_agent import detect_conflicts from app.agents.task_agent import add_task_comment
result = await detect_conflicts.ainvoke({"events": "[]"}) result = await add_task_comment.ainvoke({
"task_id": "t1",
"author": "Alice",
"content": "Looks good!",
})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "analyse" assert data["action"] == "create_record"
assert data["table"] == "taskComments"
assert data["data"]["taskId"] == "t1"
assert data["data"]["author"] == "Alice"
assert data["data"]["content"] == "Looks good!"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_suggest_reschedule_returns_valid_json(self) -> None: async def test_delete_task_comment(self) -> None:
from app.agents.calendar_agent import suggest_reschedule from app.agents.task_agent import delete_task_comment
result = await suggest_reschedule.ainvoke({"conflict": '{"event": "standup"}'}) result = await delete_task_comment.ainvoke({"comment_id": "c1"})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "suggest_reschedule" assert data["action"] == "delete_record"
assert data["table"] == "taskComments"
assert data["data"]["id"] == "c1"
# ── EmailAgent ──────────────────────────────────────────────────────── # ── CheckpointAgent ───────────────────────────────────────────────────
class TestEmailAgent: class TestCheckpointAgent:
def test_name(self) -> None: def test_name(self) -> None:
assert EmailAgent().get_name() == "email_agent" assert CheckpointAgent().get_name() == "checkpoint_agent"
def test_description(self) -> None: def test_description(self) -> None:
assert EmailAgent().get_description() == "Email analysis: classify, extract actions, draft responses" assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
def test_get_tools_count(self) -> None: def test_get_tools_count(self) -> None:
assert len(EmailAgent().get_tools()) == 3 assert len(CheckpointAgent().get_tools()) == 4
def test_tool_names(self) -> None: def test_tool_names(self) -> None:
names = {t.name for t in EmailAgent().get_tools()} names = {t.name for t in CheckpointAgent().get_tools()}
assert names == {"classify_email", "extract_action_items", "draft_response"} assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None: async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.email_agent.ChatOpenAI") as mock_cls: with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Email classified as action_required.") mock_cls.return_value = _mock_llm("No checkpoints found.")
result = await EmailAgent().handle("classify this email", {}) result = await CheckpointAgent().handle("list checkpoints", {})
assert result == "Email classified as action_required." assert result == "No checkpoints found."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_with_classify_tool_call(self) -> None: async def test_handle_with_create_tool_call(self) -> None:
with patch("app.agents.email_agent.ChatOpenAI") as mock_cls: with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call( mock_cls.return_value = _mock_llm_with_tool_call(
"classify_email", "create_checkpoint",
{"metadata": '{"subject": "URGENT: action needed"}'}, {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
"This email requires immediate action.", "Checkpoint 'MVP Launch' created.",
) )
result = await EmailAgent().handle("what is this email about?", {}) result = await CheckpointAgent().handle("add MVP checkpoint", {})
assert result == "This email requires immediate action." assert result == "Checkpoint 'MVP Launch' created."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None: async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.email_agent.ChatOpenAI") as mock_cls: with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Done.") mock_cls.return_value = _mock_llm("Done.")
result = await EmailAgent().handle("draft a reply", {}) result = await CheckpointAgent().handle("show milestones", {})
assert isinstance(result, str) assert isinstance(result, str)
class TestEmailAgentTools: class TestCheckpointAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_classify_email_returns_valid_json(self) -> None: async def test_list_checkpoints_no_project(self) -> None:
from app.agents.email_agent import classify_email from app.agents.checkpoint_agent import list_checkpoints
result = await classify_email.ainvoke({"metadata": '{"subject": "Meeting"}' }) result = await list_checkpoints.ainvoke({})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "classify" assert data["action"] == "list"
assert "result" in data assert data["table"] == "checkpoints"
assert "category" in data["result"] assert data["filters"]["projectId"] is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_extract_action_items_returns_valid_json(self) -> None: async def test_list_checkpoints_with_project(self) -> None:
from app.agents.email_agent import extract_action_items from app.agents.checkpoint_agent import list_checkpoints
result = await extract_action_items.ainvoke({"metadata": '{"subject": "Follow up"}'}) result = await list_checkpoints.ainvoke({"project_id": "p1"})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "extract" assert data["filters"]["projectId"] == "p1"
assert "action_items" in data["result"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_draft_response_returns_valid_json(self) -> None: async def test_create_checkpoint(self) -> None:
from app.agents.email_agent import draft_response from app.agents.checkpoint_agent import create_checkpoint
result = await draft_response.ainvoke({"thread_context": '{"thread_id": "t1"}'}) result = await create_checkpoint.ainvoke({
"project_id": "p1",
"title": "Beta release",
"date": 1700000000000,
})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "draft" assert data["action"] == "create_record"
assert data["table"] == "checkpoints"
assert data["data"]["projectId"] == "p1"
assert data["data"]["title"] == "Beta release"
assert data["data"]["date"] == 1700000000000
@pytest.mark.asyncio
async def test_create_checkpoint_ai_suggested(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint
result = await create_checkpoint.ainvoke({
"project_id": "p1",
"title": "Review",
"date": 1700000000000,
"is_ai_suggested": 1,
})
data = json.loads(result)
assert data["data"]["isAiSuggested"] == 1
assert data["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_update_checkpoint_approve(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({
"checkpoint_id": "c1",
"is_approved": 1,
})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "c1"
assert data["data"]["updates"]["isApproved"] == 1
@pytest.mark.asyncio
async def test_update_checkpoint_empty_updates(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_checkpoint(self) -> None:
from app.agents.checkpoint_agent import delete_checkpoint
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "checkpoints"
assert data["data"]["id"] == "c1"
# ── AnalyticsAgent ──────────────────────────────────────────────────── # ── ProjectAgent ──────────────────────────────────────────────────────
class TestAnalyticsAgent: class TestProjectAgent:
def test_name(self) -> None: def test_name(self) -> None:
assert AnalyticsAgent().get_name() == "analytics_agent" assert ProjectAgent().get_name() == "project_agent"
def test_description(self) -> None: def test_description(self) -> None:
assert AnalyticsAgent().get_description() == "Workspace analytics: metrics, reports, trends" assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
def test_get_tools_count(self) -> None: def test_get_tools_count(self) -> None:
assert len(AnalyticsAgent().get_tools()) == 3 assert len(ProjectAgent().get_tools()) == 6
def test_tool_names(self) -> None: def test_tool_names(self) -> None:
names = {t.name for t in AnalyticsAgent().get_tools()} names = {t.name for t in ProjectAgent().get_tools()}
assert names == {"calculate_metrics", "generate_report", "trend_analysis"} assert names == {
"list_projects",
"list_all_projects",
"get_project",
"create_project",
"update_project",
"delete_project",
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None: async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls: with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Completion rate is 78%.") mock_cls.return_value = _mock_llm("Project Alpha is active.")
result = await AnalyticsAgent().handle("show my metrics", {}) result = await ProjectAgent().handle("show my projects", {})
assert result == "Completion rate is 78%." assert result == "Project Alpha is active."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_with_generate_report_tool_call(self) -> None: async def test_handle_with_create_project_tool_call(self) -> None:
with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls: with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call( mock_cls.return_value = _mock_llm_with_tool_call(
"generate_report", "create_project",
{"period": "last_7_days", "data": "[]"}, {"name": "Pippo"},
"Weekly report: 12 tasks completed, 2 overdue.", "Project 'Pippo' created.",
) )
result = await AnalyticsAgent().handle("weekly report", {}) result = await ProjectAgent().handle("create project Pippo", {})
assert result == "Weekly report: 12 tasks completed, 2 overdue." assert result == "Project 'Pippo' created."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None: async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls: with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Done.") mock_cls.return_value = _mock_llm("Done.")
result = await AnalyticsAgent().handle("analyse trends", {}) result = await ProjectAgent().handle("archive old project", {})
assert isinstance(result, str) assert isinstance(result, str)
class TestAnalyticsAgentTools: class TestProjectAgentTools:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calculate_metrics_returns_valid_json(self) -> None: async def test_list_projects_defaults(self) -> None:
from app.agents.analytics_agent import calculate_metrics from app.agents.project_agent import list_projects
result = await calculate_metrics.ainvoke({"task_data": "[]"}) result = await list_projects.ainvoke({})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "calculate" assert data["action"] == "list"
assert "result" in data assert data["table"] == "projects"
assert "completion_rate" in data["result"] assert data["filters"]["includeArchived"] is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_report_returns_valid_json(self) -> None: async def test_list_projects_include_archived(self) -> None:
from app.agents.analytics_agent import generate_report from app.agents.project_agent import list_projects
result = await generate_report.ainvoke({"period": "last_7_days", "data": "[]"}) result = await list_projects.ainvoke({"include_archived": 1})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "report" assert data["filters"]["includeArchived"] is True
assert data["period"] == "last_7_days"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trend_analysis_returns_valid_json(self) -> None: async def test_list_all_projects(self) -> None:
from app.agents.analytics_agent import trend_analysis from app.agents.project_agent import list_all_projects
result = await trend_analysis.ainvoke({"data_points": "[]"}) result = await list_all_projects.ainvoke({})
data = json.loads(result) data = json.loads(result)
assert data["action"] == "trend" assert data["action"] == "list_all"
assert "result" in data assert data["table"] == "projects"
assert "anomalies" in data["result"]
@pytest.mark.asyncio
async def test_get_project(self) -> None:
from app.agents.project_agent import get_project
result = await get_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["action"] == "get"
assert data["table"] == "projects"
assert data["data"]["id"] == "p1"
@pytest.mark.asyncio
async def test_create_project_name_only(self) -> None:
from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Alpha"})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["data"]["name"] == "Alpha"
assert data["data"]["clientId"] is None
@pytest.mark.asyncio
async def test_create_project_with_client(self) -> None:
from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
data = json.loads(result)
assert data["data"]["clientId"] == "cl1"
@pytest.mark.asyncio
async def test_update_project_archive(self) -> None:
from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "p1"
assert data["data"]["updates"]["status"] == "archived"
@pytest.mark.asyncio
async def test_update_project_empty_updates(self) -> None:
from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_project(self) -> None:
from app.agents.project_agent import delete_project
result = await delete_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["data"]["id"] == "p1"
# ── NoteAgent ─────────────────────────────────────────────────────────
class TestNoteAgent:
def test_name(self) -> None:
assert NoteAgent().get_name() == "note_agent"
def test_description(self) -> None:
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(NoteAgent().get_tools()) == 5
def test_tool_names(self) -> None:
names = {t.name for t in NoteAgent().get_tools()}
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Note created.")
result = await NoteAgent().handle("create a note", {})
assert result == "Note created."
@pytest.mark.asyncio
async def test_handle_with_create_note_tool_call(self) -> None:
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_note",
{"title": "Daily log", "content": "# Today\nAll good."},
"Note 'Daily log' created.",
)
result = await NoteAgent().handle("log today's progress", {})
assert result == "Note 'Daily log' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await NoteAgent().handle("show notes", {})
assert isinstance(result, str)
class TestNoteAgentTools:
@pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None:
from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "notes"
assert data["filters"]["projectId"] is None
@pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None:
from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_get_note(self) -> None:
from app.agents.note_agent import get_note
result = await get_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["action"] == "get"
assert data["table"] == "notes"
assert data["data"]["id"] == "n1"
@pytest.mark.asyncio
async def test_create_note_minimal(self) -> None:
from app.agents.note_agent import create_note
result = await create_note.ainvoke({
"title": "Daily log",
"content": "# Today\nAll good.",
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "notes"
assert data["data"]["title"] == "Daily log"
assert data["data"]["content"] == "# Today\nAll good."
assert data["data"]["projectId"] is None
@pytest.mark.asyncio
async def test_create_note_with_project(self) -> None:
from app.agents.note_agent import create_note
result = await create_note.ainvoke({
"title": "Sprint notes",
"content": "## Sprint 1",
"project_id": "p1",
})
data = json.loads(result)
assert data["data"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_update_note_content_only(self) -> None:
from app.agents.note_agent import update_note
result = await update_note.ainvoke({
"note_id": "n1",
"content": "# Updated content",
})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "n1"
assert data["data"]["updates"]["content"] == "# Updated content"
assert "title" not in data["data"]["updates"]
@pytest.mark.asyncio
async def test_update_note_empty_updates(self) -> None:
from app.agents.note_agent import update_note
result = await update_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_note(self) -> None:
from app.agents.note_agent import delete_note
result = await delete_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "notes"
assert data["data"]["id"] == "n1"

View File

@@ -243,14 +243,14 @@ class TestPlanCache:
class TestModuleSingletons: class TestModuleSingletons:
def test_template_registry_has_all_agent_defaults(self) -> None: def test_template_registry_has_all_agent_defaults(self) -> None:
for agent in ("task_agent", "calendar_agent", "email_agent", "analytics_agent"): for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
assert template_registry.has(f"tpl_{agent}_default"), ( assert template_registry.has(f"tpl_{agent}_default"), (
f"Missing template: tpl_{agent}_default" f"Missing template: tpl_{agent}_default"
) )
def test_template_registry_has_operation_templates(self) -> None: def test_template_registry_has_operation_templates(self) -> None:
assert template_registry.has("tpl_email_extract_action_items") assert template_registry.has("tpl_task_extract_from_project")
assert template_registry.has("tpl_analytics_weekly_summary") assert template_registry.has("tpl_note_weekly_summary")
def test_template_registry_get_returns_non_empty_string(self) -> None: def test_template_registry_get_returns_non_empty_string(self) -> None:
text = template_registry.get("tpl_task_agent_default") text = template_registry.get("tpl_task_agent_default")
@@ -260,20 +260,20 @@ class TestModuleSingletons:
def test_plan_cache_has_prebuilt_playbooks(self) -> None: def test_plan_cache_has_prebuilt_playbooks(self) -> None:
assert len(plan_cache.get_all_playbooks()) >= 2 assert len(plan_cache.get_all_playbooks()) >= 2
def test_playbook_create_task_from_email(self) -> None: def test_playbook_create_tasks_from_project(self) -> None:
plan = plan_cache.get_plan("create_task_from_email") plan = plan_cache.get_plan("create_tasks_from_project")
assert plan is not None assert plan is not None
assert plan.agent == "email_agent" assert plan.agent == "project_agent"
assert len(plan.steps) == 2 assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_email_extract_action_items" assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
assert plan.steps[1].data_from_step == 0 assert plan.steps[1].data_from_step == 0
def test_playbook_generate_weekly_report(self) -> None: def test_playbook_generate_weekly_note(self) -> None:
plan = plan_cache.get_plan("generate_weekly_report") plan = plan_cache.get_plan("generate_weekly_note")
assert plan is not None assert plan is not None
assert plan.agent == "analytics_agent" assert plan.agent == "note_agent"
assert len(plan.steps) == 2 assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_analytics_weekly_summary" assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
assert plan.steps[1].data_from_step == 0 assert plan.steps[1].data_from_step == 0
def test_playbook_steps_have_no_raw_prompt_text(self) -> None: def test_playbook_steps_have_no_raw_prompt_text(self) -> None:

385
tests/test_storage.py Normal file
View File

@@ -0,0 +1,385 @@
"""Tests for the storage layer: encryption, BlobStore, and VectorStore."""
from __future__ import annotations
import base64
import hashlib
import os
from unittest.mock import MagicMock, patch
import boto3
import pytest
from botocore.exceptions import ClientError
from moto import mock_aws
from app.storage.encryption import reject_if_tampered, verify_checksum
from app.storage.blob_store import BlobStore
from app.storage.vector_store import VectorStore, _blob_to_vector
from app.schemas import VectorItem, VectorSearchResult
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-payload-opaque-to-server"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_BUCKET = "test-bucket"
_REGION = "us-east-1"
@pytest.fixture
def s3_bucket():
"""Create a mocked S3 bucket and expose its name."""
with mock_aws():
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
os.environ.setdefault("AWS_DEFAULT_REGION", _REGION)
client = boto3.client("s3", region_name=_REGION)
client.create_bucket(Bucket=_BUCKET)
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = _BUCKET
mock_settings.S3_REGION = _REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
yield _BUCKET
def _pinecone_mock():
"""Return a mock Pinecone index with realistic return shapes."""
mock_index = MagicMock()
mock_index.query.return_value = {
"matches": [
{
"id": "v1",
"score": 0.95,
"metadata": {
"blob": base64.b64encode(b"result-blob").decode(),
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
"user_id": "u1",
},
}
]
}
mock_pc = MagicMock()
mock_pc.return_value.Index.return_value = mock_index
return mock_pc, mock_index
# ── TestEncryption ────────────────────────────────────────────────────
class TestEncryption:
def test_verify_checksum_correct(self) -> None:
assert verify_checksum(_BLOB, _CHECKSUM) is True
def test_verify_checksum_wrong(self) -> None:
assert verify_checksum(_BLOB, "0" * 64) is False
def test_verify_checksum_empty_checksum(self) -> None:
assert verify_checksum(_BLOB, "") is False
def test_verify_checksum_empty_blob(self) -> None:
expected = hashlib.sha256(b"").hexdigest()
assert verify_checksum(b"", expected) is True
def test_verify_checksum_tampered_blob(self) -> None:
tampered = _BLOB + b"\x00"
assert verify_checksum(tampered, _CHECKSUM) is False
def test_reject_if_tampered_passes_when_valid(self) -> None:
# Should not raise
reject_if_tampered(_BLOB, _CHECKSUM)
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert exc_info.value.status_code == 400
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert "checksum" in exc_info.value.detail.lower()
def test_checksum_is_sha256_hex(self) -> None:
cs = hashlib.sha256(_BLOB).hexdigest()
assert len(cs) == 64
assert all(c in "0123456789abcdef" for c in cs)
# ── TestBlobStore ─────────────────────────────────────────────────────
class TestBlobStore:
@pytest.mark.asyncio
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
store = BlobStore()
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
assert key == "u1/tasks/r1"
@pytest.mark.asyncio
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify by downloading — no exception means object exists
retrieved = await store.download("u1", "u1/tasks/r1")
assert retrieved == _BLOB
@pytest.mark.asyncio
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
result = await store.download("u1", "u1/notes/n1")
assert result == b"note-data"
@pytest.mark.asyncio
async def test_delete_removes_object(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.delete("u1", "u1/tasks/r1")
with pytest.raises(ClientError) as exc_info:
await store.download("u1", "u1/tasks/r1")
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
@pytest.mark.asyncio
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
store = BlobStore()
# Delete a key that never existed — should not raise
await store.delete("u1", "u1/tasks/nonexistent")
@pytest.mark.asyncio
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
@pytest.mark.asyncio
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert "u1/notes/n1" not in keys
assert "u1/tasks/r1" in keys
@pytest.mark.asyncio
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
keys_u1 = await store.list_keys("u1", "tasks")
assert "u2/tasks/r1" not in keys_u1
@pytest.mark.asyncio
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
store = BlobStore()
keys = await store.list_keys("u1", "tasks")
assert keys == []
@pytest.mark.asyncio
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify S3 metadata was set — check via head_object
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = _BUCKET
mock_settings.S3_REGION = _REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response.get("ServerSideEncryption") == "AES256"
@pytest.mark.asyncio
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response["Metadata"]["checksum"] == _CHECKSUM
# ── _blob_to_vector helper ────────────────────────────────────────────
class TestBlobToVector:
def test_returns_32_floats(self) -> None:
v = _blob_to_vector(b"test")
assert len(v) == 32
def test_all_values_in_range(self) -> None:
v = _blob_to_vector(b"test")
assert all(-1.0 <= x <= 1.0 for x in v)
def test_deterministic(self) -> None:
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
def test_different_blobs_different_vectors(self) -> None:
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
# ── TestVectorStorePinecone ───────────────────────────────────────────
class TestVectorStorePinecone:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: True # type: ignore[method-assign]
return store
@pytest.mark.asyncio
async def test_upsert_calls_index_upsert(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
await store.upsert("u1", items)
mock_index.upsert.assert_called_once()
call_kwargs = mock_index.upsert.call_args[1]
assert call_kwargs.get("namespace") == "u1"
@pytest.mark.asyncio
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
await store.upsert("u1", items)
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
@pytest.mark.asyncio
async def test_search_calls_index_query(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=5)
mock_index.query.assert_called_once()
query_kwargs = mock_index.query.call_args[1]
assert query_kwargs.get("namespace") == "u1"
assert query_kwargs.get("top_k") == 5
assert query_kwargs.get("include_metadata") is True
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
results = await store.search("u1", b"query", top_k=10)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.95
assert results[0].blob == b"result-blob"
@pytest.mark.asyncio
async def test_search_uses_derived_query_vector(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=3)
expected_vector = _blob_to_vector(b"query-blob")
actual_vector = mock_index.query.call_args[1].get("vector")
assert actual_vector == expected_vector
@pytest.mark.asyncio
async def test_delete_calls_index_delete(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_index.delete.assert_called_once()
delete_kwargs = mock_index.delete.call_args[1]
assert delete_kwargs.get("namespace") == "u1"
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
class TestVectorStoreQdrant:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: False # type: ignore[method-assign]
return store
def _qdrant_mock(self) -> MagicMock:
mock_hit = MagicMock()
mock_hit.id = "v1"
mock_hit.score = 0.88
mock_hit.payload = {
"blob": base64.b64encode(b"qdrant-result").decode(),
"user_id": "u1",
}
mock_client = MagicMock()
mock_client.search.return_value = [mock_hit]
return mock_client
@pytest.mark.asyncio
async def test_upsert_calls_client_upsert(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
mock_client.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_upsert_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
call_kwargs = mock_client.upsert.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
@pytest.mark.asyncio
async def test_search_calls_client_search(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=5)
mock_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_search_passes_limit(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=7)
call_kwargs = mock_client.search.call_args[1]
assert call_kwargs.get("limit") == 7
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
results = await store.search("u1", b"query", top_k=5)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.88
assert results[0].blob == b"qdrant-result"
@pytest.mark.asyncio
async def test_delete_calls_client_delete(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_client.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1"])
call_kwargs = mock_client.delete.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"