From c8ef7b119b12f8384991d7ada1df5f04665a51ca Mon Sep 17 00:00:00 2001 From: roberto Date: Mon, 2 Mar 2026 15:36:09 +0100 Subject: [PATCH] Refactor tests for execution plan and add comprehensive storage tests - Updated `TestModuleSingletons` in `test_execution_plan.py` to reflect new agent templates and playbook names. - Changed assertions in playbook tests to match updated templates and agents. - Introduced `test_storage.py` to cover the storage layer, including encryption, BlobStore, and VectorStore functionalities. - Added tests for S3 interactions, ensuring upload, download, delete, and list operations work as expected. - Implemented mock tests for Pinecone and Qdrant vector stores to validate upsert, search, and delete operations. --- app/agents/__init__.py | 4 +- app/agents/analytics_agent.py | 80 ----- app/agents/calendar_agent.py | 76 ----- app/agents/checkpoint_agent.py | 122 +++++++ app/agents/email_agent.py | 77 ----- app/agents/note_agent.py | 123 +++++++ app/agents/project_agent.py | 158 +++++++++ app/agents/task_agent.py | 181 +++++++++-- app/api/deps.py | 46 +++ app/api/routes/auth.py | 118 +++++++ app/config/settings.py | 5 + app/core/execution_plan.py | 54 +-- app/schemas.py | 73 +++++ app/storage/__init__.py | 1 + app/storage/blob_store.py | 105 ++++++ app/storage/encryption.py | 32 ++ app/storage/vector_store.py | 205 ++++++++++++ requirements.txt | 3 + tests/test_agents.py | 579 +++++++++++++++++++++++---------- tests/test_execution_plan.py | 22 +- tests/test_storage.py | 385 ++++++++++++++++++++++ 21 files changed, 1980 insertions(+), 469 deletions(-) delete mode 100644 app/agents/analytics_agent.py delete mode 100644 app/agents/calendar_agent.py create mode 100644 app/agents/checkpoint_agent.py delete mode 100644 app/agents/email_agent.py create mode 100644 app/agents/note_agent.py create mode 100644 app/agents/project_agent.py create mode 100644 app/api/deps.py create mode 100644 app/api/routes/auth.py create mode 100644 app/storage/__init__.py create mode 100644 app/storage/blob_store.py create mode 100644 app/storage/encryption.py create mode 100644 app/storage/vector_store.py create mode 100644 tests/test_storage.py diff --git a/app/agents/__init__.py b/app/agents/__init__.py index a2c8d21..a511527 100644 --- a/app/agents/__init__.py +++ b/app/agents/__init__.py @@ -1,5 +1,5 @@ """Import all agent modules to trigger @registry.register decorators.""" -from app.agents import analytics_agent, calendar_agent, email_agent, task_agent +from app.agents import checkpoint_agent, note_agent, project_agent, task_agent -__all__ = ["analytics_agent", "calendar_agent", "email_agent", "task_agent"] +__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"] diff --git a/app/agents/analytics_agent.py b/app/agents/analytics_agent.py deleted file mode 100644 index 1b8e99f..0000000 --- a/app/agents/analytics_agent.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Analytics agent — metrics, reports, and trend analysis.""" - -from __future__ import annotations - -import json -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.tools import tool -from langchain_openai import ChatOpenAI - -from app.config.settings import settings -from app.core.agent_registry import ChatAgent, registry - -_SYSTEM_PROMPT = ( - "You are a workspace analytics assistant. Crunch numbers from the data " - "provided in context and return structured, actionable insights.\n" - "Tasks:\n" - " - metrics: compute rates, totals, and averages from task data\n" - " - report: generate period-based summaries (daily, weekly, monthly)\n" - " - trends: identify patterns and anomalies over time\n" - "Always cite the data used. Do not fabricate figures." -) - - -@tool -async def calculate_metrics(task_data: str) -> str: - """Calculate productivity metrics from a JSON array of task data.""" - return json.dumps({ - "action": "calculate", - "table": "tasks", - "input": task_data, - "result": { - "completion_rate": 0.0, - "overdue_count": 0, - "avg_priority": "medium", - }, - }) - - -@tool -async def generate_report(period: str, data: str) -> str: - """Generate a structured report for a time period (e.g. 'last_7_days', 'last_month').""" - return json.dumps({ - "action": "report", - "period": period, - "input": data, - }) - - -@tool -async def trend_analysis(data_points: str) -> str: - """Analyse trends in a JSON array of time-series data points.""" - return json.dumps({ - "action": "trend", - "input": data_points, - "result": {"trend": "stable", "anomalies": []}, - }) - - -@registry.register -class AnalyticsAgent(ChatAgent): - def get_name(self) -> str: - return "analytics_agent" - - def get_description(self) -> str: - return "Workspace analytics: metrics, reports, trends" - - def get_tools(self) -> list[Any]: - return [calculate_metrics, generate_report, trend_analysis] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/calendar_agent.py b/app/agents/calendar_agent.py deleted file mode 100644 index f546e15..0000000 --- a/app/agents/calendar_agent.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Calendar agent — events, conflict detection, and scheduling.""" - -from __future__ import annotations - -import json -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.tools import tool -from langchain_openai import ChatOpenAI - -from app.config.settings import settings -from app.core.agent_registry import ChatAgent, registry - -_SYSTEM_PROMPT = ( - "You are a calendar management assistant. Help the user manage events, " - "detect scheduling conflicts, and suggest reschedules.\n" - "Rules:\n" - " - Work exclusively with event metadata provided in context\n" - " - Never store or reference raw calendar data\n" - " - date_range format: ISO 8601 interval, e.g. '2024-01-01/2024-01-07'\n" - " - Always confirm the date/time scope of any operation" -) - - -@tool -async def list_events(date_range: str) -> str: - """List calendar events in a date range (ISO 8601 interval, e.g. '2024-01-01/2024-01-07').""" - return json.dumps({ - "action": "list", - "table": "events", - "filters": {"date_range": date_range}, - }) - - -@tool -async def detect_conflicts(events: str) -> str: - """Detect scheduling conflicts in a JSON array of event metadata objects.""" - return json.dumps({ - "action": "analyse", - "table": "events", - "input": events, - "result": "conflicts_detected", - }) - - -@tool -async def suggest_reschedule(conflict: str) -> str: - """Suggest a reschedule for a conflicting event. Pass the conflict as a JSON string.""" - return json.dumps({ - "action": "suggest_reschedule", - "table": "events", - "input": conflict, - }) - - -@registry.register -class CalendarAgent(ChatAgent): - def get_name(self) -> str: - return "calendar_agent" - - def get_description(self) -> str: - return "Calendar management: events, conflicts, scheduling" - - def get_tools(self) -> list[Any]: - return [list_events, detect_conflicts, suggest_reschedule] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/checkpoint_agent.py b/app/agents/checkpoint_agent.py new file mode 100644 index 0000000..9410aab --- /dev/null +++ b/app/agents/checkpoint_agent.py @@ -0,0 +1,122 @@ +"""Checkpoint agent — project milestone management (list, create, update, delete).""" + +from __future__ import annotations + +import json +from typing import Any + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.tools import tool +from langchain_openai import ChatOpenAI + +from app.config.settings import settings +from app.core.agent_registry import ChatAgent, registry + +_SYSTEM_PROMPT = ( + "You are a project checkpoint assistant. Checkpoints are milestone dates that\n" + "track progress on a project — they are not calendar events.\n\n" + "Rules:\n" + " - project_id is REQUIRED for every create; confirm with the user if unknown\n" + " - date is a Unix timestamp in milliseconds; convert human-readable dates\n" + " - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n" + " - is_approved: 0 until the user explicitly confirms; then 1\n" + " - For update_checkpoint, use -1 for integer fields you do not want to change\n" + " - Listing without a project_id returns all checkpoints across projects\n" + " - Always echo the title and formatted date in your confirmation." +) + + +@tool +async def list_checkpoints(project_id: str = "") -> str: + """List checkpoints. Provide project_id to scope to a specific project.""" + return json.dumps({ + "action": "list", + "table": "checkpoints", + "filters": {"projectId": project_id or None}, + }) + + +@tool +async def create_checkpoint( + project_id: str, + title: str, + date: int, + is_ai_suggested: int = 0, + is_approved: int = 0, +) -> str: + """Create a project checkpoint (milestone). + project_id: REQUIRED UUID of the parent project + title: descriptive name for the milestone + date: Unix timestamp in milliseconds + is_ai_suggested: 1 if proactively suggested, 0 if user-requested + is_approved: 0 until the user confirms + """ + return json.dumps({ + "action": "create_record", + "table": "checkpoints", + "data": { + "projectId": project_id, + "title": title, + "date": date, + "isAiSuggested": is_ai_suggested, + "isApproved": is_approved, + }, + }) + + +@tool +async def update_checkpoint( + checkpoint_id: str, + title: str = "", + date: int = -1, + is_approved: int = -1, +) -> str: + """Update a checkpoint. Only pass fields that should change. + checkpoint_id: UUID of the checkpoint (required) + date: -1 means unchanged; any other value sets the new date (ms timestamp) + is_approved: -1 means unchanged; 0 or 1 sets the approval state + """ + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if date != -1: + updates["date"] = date + if is_approved != -1: + updates["isApproved"] = is_approved + return json.dumps({ + "action": "update_record", + "table": "checkpoints", + "data": {"id": checkpoint_id, "updates": updates}, + }) + + +@tool +async def delete_checkpoint(checkpoint_id: str) -> str: + """Delete a checkpoint permanently by its UUID.""" + return json.dumps({ + "action": "delete_record", + "table": "checkpoints", + "data": {"id": checkpoint_id}, + }) + + +@registry.register +class CheckpointAgent(ChatAgent): + def get_name(self) -> str: + return "checkpoint_agent" + + def get_description(self) -> str: + return "Manages project checkpoints (milestones): list, create, update, delete" + + def get_tools(self) -> list[Any]: + return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint] + + async def handle(self, query: str, context: dict[str, Any]) -> str: + llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) + messages = [ + SystemMessage(content=_SYSTEM_PROMPT), + HumanMessage( + content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" + ), + ] + return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/email_agent.py b/app/agents/email_agent.py deleted file mode 100644 index 656f88a..0000000 --- a/app/agents/email_agent.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Email agent — classify, extract action items, draft responses.""" - -from __future__ import annotations - -import json -from typing import Any - -from langchain_core.messages import HumanMessage, SystemMessage -from langchain_core.tools import tool -from langchain_openai import ChatOpenAI - -from app.config.settings import settings -from app.core.agent_registry import ChatAgent, registry - -_SYSTEM_PROMPT = ( - "You are an email analysis assistant. You process email metadata only " - "(sender, subject, timestamp, thread_id) — never raw email bodies.\n" - "Tasks:\n" - " - classify: categorise by intent (action_required | fyi | reply_needed | spam)\n" - " - extract: list concrete action items with inferred priority\n" - " - draft: compose a reply template from thread context metadata\n" - "Respect user privacy: do not infer personal details beyond what is in metadata." -) - - -@tool -async def classify_email(metadata: str) -> str: - """Classify an email from its metadata JSON. Returns category and confidence score.""" - return json.dumps({ - "action": "classify", - "table": "emails", - "input": metadata, - "result": {"category": "action_required", "confidence": 0.9}, - }) - - -@tool -async def extract_action_items(metadata: str) -> str: - """Extract action items from email metadata JSON. Returns a list of task descriptions.""" - return json.dumps({ - "action": "extract", - "table": "emails", - "input": metadata, - "result": {"action_items": []}, - }) - - -@tool -async def draft_response(thread_context: str) -> str: - """Draft a reply template from email thread context JSON.""" - return json.dumps({ - "action": "draft", - "table": "emails", - "input": thread_context, - }) - - -@registry.register -class EmailAgent(ChatAgent): - def get_name(self) -> str: - return "email_agent" - - def get_description(self) -> str: - return "Email analysis: classify, extract actions, draft responses" - - def get_tools(self) -> list[Any]: - return [classify_email, extract_action_items, draft_response] - - async def handle(self, query: str, context: dict[str, Any]) -> str: - llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) - messages = [ - SystemMessage(content=_SYSTEM_PROMPT), - HumanMessage( - content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" - ), - ] - return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/note_agent.py b/app/agents/note_agent.py new file mode 100644 index 0000000..65898cc --- /dev/null +++ b/app/agents/note_agent.py @@ -0,0 +1,123 @@ +"""Note agent — Markdown note management (list, get, create, update, delete).""" + +from __future__ import annotations + +import json +from typing import Any + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.tools import tool +from langchain_openai import ChatOpenAI + +from app.config.settings import settings +from app.core.agent_registry import ChatAgent, registry + +_SYSTEM_PROMPT = ( + "You are a note-taking assistant. You help users create, retrieve, update,\n" + "and delete Markdown notes in their workspace.\n\n" + "Rules:\n" + " - content is always Markdown; preserve formatting when updating\n" + " - project_id is optional; link a note to a project when mentioned\n" + " - When updating, call get_note first if you need to read existing content\n" + " before appending or replacing sections\n" + " - list_notes without project_id returns all notes; scope with project_id\n" + " when the user is working within a specific project\n" + " - Do not fabricate note content — reflect what the user provides or what\n" + " is already in the note (retrieved via get_note)." +) + + +@tool +async def list_notes(project_id: str = "") -> str: + """List notes, optionally scoped to a project by project_id.""" + return json.dumps({ + "action": "list", + "table": "notes", + "filters": {"projectId": project_id or None}, + }) + + +@tool +async def get_note(note_id: str) -> str: + """Fetch a single note by its UUID to read its full Markdown content.""" + return json.dumps({ + "action": "get", + "table": "notes", + "data": {"id": note_id}, + }) + + +@tool +async def create_note( + title: str, + content: str, + project_id: str = "", +) -> str: + """Create a new note. + title: note heading (required) + content: Markdown body text (required) + project_id: optional UUID linking this note to a project + """ + return json.dumps({ + "action": "create_record", + "table": "notes", + "data": { + "title": title, + "content": content, + "projectId": project_id or None, + }, + }) + + +@tool +async def update_note( + note_id: str, + title: str = "", + content: str = "", +) -> str: + """Update an existing note. Only pass fields that should change. + note_id: UUID of the note (required) + If you need to preserve existing content, call get_note first. + """ + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if content: + updates["content"] = content + return json.dumps({ + "action": "update_record", + "table": "notes", + "data": {"id": note_id, "updates": updates}, + }) + + +@tool +async def delete_note(note_id: str) -> str: + """Delete a note permanently by its UUID.""" + return json.dumps({ + "action": "delete_record", + "table": "notes", + "data": {"id": note_id}, + }) + + +@registry.register +class NoteAgent(ChatAgent): + def get_name(self) -> str: + return "note_agent" + + def get_description(self) -> str: + return "Manages notes: list, get, create, update, delete" + + def get_tools(self) -> list[Any]: + return [list_notes, get_note, create_note, update_note, delete_note] + + async def handle(self, query: str, context: dict[str, Any]) -> str: + llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) + messages = [ + SystemMessage(content=_SYSTEM_PROMPT), + HumanMessage( + content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" + ), + ] + return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/project_agent.py b/app/agents/project_agent.py new file mode 100644 index 0000000..1054386 --- /dev/null +++ b/app/agents/project_agent.py @@ -0,0 +1,158 @@ +"""Project agent — full lifecycle management (list, get, create, update, archive, delete).""" + +from __future__ import annotations + +import json +from typing import Any + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.tools import tool +from langchain_openai import ChatOpenAI + +from app.config.settings import settings +from app.core.agent_registry import ChatAgent, registry + +_SYSTEM_PROMPT = ( + "You are a project management assistant. You help users create, find,\n" + "update, and archive projects in their workspace.\n\n" + "Rules:\n" + " - status must be one of: active, archived\n" + " - client_id is optional; link to a client only when explicitly mentioned\n" + " - ai_summary is populated only when the user asks for a project summary;\n" + " derive it from context data — do not fabricate content\n" + " - Use list_projects for scoped queries; list_all_projects only when the\n" + " user wants a complete cross-client view including archived projects\n" + " - get_project requires a project UUID; resolve the ID first by calling\n" + " list_projects if you only have a project name\n" + " - Prefer archiving (update_project status=archived) over deletion;\n" + " only call delete_project when the user explicitly confirms deletion." +) + + +@tool +async def list_projects( + client_id: str = "", + include_archived: int = 0, +) -> str: + """List projects, optionally filtered by client_id. + include_archived: 1 to include archived projects, 0 for active only (default). + """ + return json.dumps({ + "action": "list", + "table": "projects", + "filters": { + "clientId": client_id or None, + "includeArchived": bool(include_archived), + }, + }) + + +@tool +async def list_all_projects() -> str: + """List every project regardless of client or status. + Use only when the user wants a complete cross-client overview. + """ + return json.dumps({ + "action": "list_all", + "table": "projects", + }) + + +@tool +async def get_project(project_id: str) -> str: + """Fetch a single project by its UUID.""" + return json.dumps({ + "action": "get", + "table": "projects", + "data": {"id": project_id}, + }) + + +@tool +async def create_project( + name: str, + client_id: str = "", +) -> str: + """Create a new project. + name: human-readable project name (required) + client_id: optional UUID of the owning client + """ + return json.dumps({ + "action": "create_record", + "table": "projects", + "data": { + "name": name, + "clientId": client_id or None, + }, + }) + + +@tool +async def update_project( + project_id: str, + name: str = "", + client_id: str = "", + status: str = "", + ai_summary: str = "", +) -> str: + """Update a project. Only pass fields that should change. + project_id: UUID of the project (required) + status: active | archived + ai_summary: AI-generated summary text (populate only when explicitly requested) + """ + updates: dict[str, Any] = {} + if name: + updates["name"] = name + if client_id: + updates["clientId"] = client_id + if status: + updates["status"] = status + if ai_summary: + updates["aiSummary"] = ai_summary + return json.dumps({ + "action": "update_record", + "table": "projects", + "data": {"id": project_id, "updates": updates}, + }) + + +@tool +async def delete_project(project_id: str) -> str: + """Permanently delete a project and orphan its tasks. + IMPORTANT: prefer update_project(status='archived') unless the user + has explicitly confirmed they want permanent deletion. + """ + return json.dumps({ + "action": "delete_record", + "table": "projects", + "data": {"id": project_id}, + }) + + +@registry.register +class ProjectAgent(ChatAgent): + def get_name(self) -> str: + return "project_agent" + + def get_description(self) -> str: + return "Manages projects: list, get, create, update, archive, delete" + + def get_tools(self) -> list[Any]: + return [ + list_projects, + list_all_projects, + get_project, + create_project, + update_project, + delete_project, + ] + + async def handle(self, query: str, context: dict[str, Any]) -> str: + llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) + messages = [ + SystemMessage(content=_SYSTEM_PROMPT), + HumanMessage( + content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}" + ), + ] + return await self._tool_loop(llm, messages, self.get_tools()) diff --git a/app/agents/task_agent.py b/app/agents/task_agent.py index 2beab66..df1d3c0 100644 --- a/app/agents/task_agent.py +++ b/app/agents/task_agent.py @@ -1,4 +1,4 @@ -"""Task agent — create, update, list, and suggest tasks.""" +"""Task agent — full CRUD for tasks and task comments.""" from __future__ import annotations @@ -13,40 +13,121 @@ from app.config.settings import settings from app.core.agent_registry import ChatAgent, registry _SYSTEM_PROMPT = ( - "You are a task management assistant (PM-oriented). Help the user create, " - "update, list, and suggest tasks.\n" + "You are a task management assistant for a project workspace.\n" + "You create, update, list, and track tasks and their comments.\n\n" "Rules:\n" - " - priority must be one of: low, medium, high, urgent\n" - " - infer priority from context clues (deadlines, urgency language, dependencies)\n" - " - due_date as ISO 8601 string when provided\n" - " - context fields beyond user_profile are optional; use them when present\n" - "Use the available tools to act, then confirm what was done in plain language." + " - status must be one of: todo, in_progress, done\n" + " - priority must be one of: high, medium, low\n" + " - due_date is a Unix timestamp in milliseconds; convert human dates\n" + " - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n" + " - project_id is optional; link to a project when the user mentions one\n" + " - is_ai_suggested: 1 only when proactively proposing a task the user\n" + " did not explicitly request; 0 otherwise\n" + " - is_approved defaults to 0; set to 1 only when the user confirms\n" + " - Use list_tasks_due_today for 'what's due today' queries\n" + " - For update_task, use -1 for integer fields you do not want to change\n" + " - Always confirm the action in plain, user-friendly language." ) +# ── Task tools ──────────────────────────────────────────────────────── + + +@tool +async def list_tasks( + project_id: str = "", + status: str = "", + search: str = "", + order_by: str = "", +) -> str: + """List tasks, optionally filtered by project_id, status (todo|in_progress|done), + a search string, or an order_by field name (dueDate|priority|createdAt).""" + return json.dumps({ + "action": "list", + "table": "tasks", + "filters": { + "projectId": project_id or None, + "status": status or None, + "search": search or None, + "orderBy": order_by or None, + }, + }) + + @tool async def create_task( title: str, description: str = "", + status: str = "todo", priority: str = "medium", - due_date: str = "", + assignees: str = "[]", + due_date: int = 0, + project_id: str = "", + is_ai_suggested: int = 0, + is_approved: int = 0, ) -> str: - """Create a new task. priority: low | medium | high | urgent. due_date: ISO 8601.""" + """Create a new task. + title: task title (required) + description: optional details + status: todo | in_progress | done (default: todo) + priority: high | medium | low (default: medium) + assignees: JSON-encoded array of assignee names, e.g. '["Alice"]' + due_date: Unix timestamp in milliseconds; 0 means no due date + project_id: optional UUID of the parent project + is_ai_suggested: 1 if proactively suggested, 0 if user-requested + is_approved: 0 until the user confirms; 1 when confirmed + """ return json.dumps({ "action": "create_record", "table": "tasks", "data": { "title": title, - "description": description, + "description": description or None, + "status": status, "priority": priority, - "due_date": due_date, + "assignee": assignees, + "dueDate": due_date or None, + "projectId": project_id or None, + "isAiSuggested": is_ai_suggested, + "isApproved": is_approved, }, }) @tool -async def update_task(task_id: str, updates: str) -> str: - """Update fields on an existing task. Pass updates as a JSON string, e.g. '{"priority":"high"}'.""" +async def update_task( + task_id: str, + title: str = "", + description: str = "", + status: str = "", + priority: str = "", + assignees: str = "", + due_date: int = -1, + project_id: str = "", + is_approved: int = -1, +) -> str: + """Update fields on an existing task. Only pass fields you want to change. + task_id: the task's UUID (required) + due_date: -1 means unchanged; 0 clears the due date; any positive value sets it + is_approved: -1 means unchanged; 0 or 1 sets the value + """ + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if description: + updates["description"] = description + if status: + updates["status"] = status + if priority: + updates["priority"] = priority + if assignees: + updates["assignee"] = assignees + if due_date != -1: + updates["dueDate"] = due_date or None + if project_id: + updates["projectId"] = project_id + if is_approved != -1: + updates["isApproved"] = is_approved return json.dumps({ "action": "update_record", "table": "tasks", @@ -55,35 +136,87 @@ async def update_task(task_id: str, updates: str) -> str: @tool -async def list_tasks(status: str = "", priority: str = "") -> str: - """List tasks. Optionally filter by status (open|done|archived) or priority level.""" +async def delete_task(task_id: str) -> str: + """Delete a task permanently by its UUID.""" return json.dumps({ - "action": "list", + "action": "delete_record", "table": "tasks", - "filters": {"status": status, "priority": priority}, + "data": {"id": task_id}, }) @tool -async def suggest_tasks(context: str) -> str: - """Suggest new tasks based on notes or free-form context text.""" +async def list_tasks_due_today() -> str: + """List all tasks whose due date falls on today's date.""" return json.dumps({ - "action": "suggest", + "action": "list_due_today", "table": "tasks", - "context": context, }) +# ── Task comment tools ──────────────────────────────────────────────── + + +@tool +async def list_task_comments(task_id: str) -> str: + """List all comments on a task by its UUID.""" + return json.dumps({ + "action": "list", + "table": "taskComments", + "filters": {"taskId": task_id}, + }) + + +@tool +async def add_task_comment(task_id: str, author: str, content: str) -> str: + """Add a comment to a task. + task_id: UUID of the task to comment on + author: name or ID of the comment author + content: comment text + """ + return json.dumps({ + "action": "create_record", + "table": "taskComments", + "data": { + "taskId": task_id, + "author": author, + "content": content, + }, + }) + + +@tool +async def delete_task_comment(comment_id: str) -> str: + """Delete a task comment by its UUID.""" + return json.dumps({ + "action": "delete_record", + "table": "taskComments", + "data": {"id": comment_id}, + }) + + +# ── Agent ───────────────────────────────────────────────────────────── + + @registry.register class TaskAgent(ChatAgent): def get_name(self) -> str: return "task_agent" def get_description(self) -> str: - return "Manages tasks: create, update, list, suggest" + return "Manages tasks and comments: list, create, update, delete, due-today, comments" def get_tools(self) -> list[Any]: - return [create_task, update_task, list_tasks, suggest_tasks] + return [ + list_tasks, + create_task, + update_task, + delete_task, + list_tasks_due_today, + list_task_comments, + add_task_comment, + delete_task_comment, + ] async def handle(self, query: str, context: dict[str, Any]) -> str: llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY) diff --git a/app/api/deps.py b/app/api/deps.py new file mode 100644 index 0000000..a8fb393 --- /dev/null +++ b/app/api/deps.py @@ -0,0 +1,46 @@ +"""Shared FastAPI dependencies. + +``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``. +Step 9 will layer rate-limiting and sanitization middleware on top of this. +Step 12 will add a DB look-up to fetch the live tier from PostgreSQL. +""" + +from __future__ import annotations + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt + +from app.config.settings import settings +from app.schemas import BillingTier, UserProfile + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +async def get_current_user( + token: str = Depends(oauth2_scheme), +) -> UserProfile: + """Validate a Bearer JWT and return the authenticated user. + + Raises ``HTTP 401`` on any invalid or expired token. + The tier embedded in the JWT is used for feature-gating until Step 12 + adds a live DB lookup. + """ + credentials_exc = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str | None = payload.get("sub") + email: str | None = payload.get("email") + tier: str = payload.get("tier", "free") + if not user_id or not email: + raise credentials_exc + except JWTError: + raise credentials_exc + + return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type] diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py new file mode 100644 index 0000000..64c0bf5 --- /dev/null +++ b/app/api/routes/auth.py @@ -0,0 +1,118 @@ +"""Auth routes: register, login, refresh, me. + +Users and refresh tokens are kept in an in-memory dict until Step 12 +migrates them to PostgreSQL. +""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +import bcrypt +from fastapi import APIRouter, Depends, HTTPException, status +from jose import jwt +from pydantic import BaseModel + +from app.api.deps import get_current_user +from app.config.settings import settings +from app.schemas import AuthTokens, UserProfile + +router = APIRouter(prefix="/auth", tags=["auth"]) + +# ── In-memory stores (replaced by PostgreSQL in Step 12) ───────────── +_users: dict[str, dict[str, Any]] = {} # email → user record +_refresh_tokens: dict[str, str] = {} # plain token → user_id + + +# ── Internal helpers ───────────────────────────────────────────────── + +def _hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + +def _verify_password(password: str, hashed: str) -> bool: + return bcrypt.checkpw(password.encode(), hashed.encode()) + + +def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens: + now = int(time.time()) + access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 + access_payload = { + "sub": user_id, + "email": email, + "tier": tier, + "exp": access_exp, + "iat": now, + } + access_token = jwt.encode( + access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM + ) + refresh_token = str(uuid.uuid4()) + _refresh_tokens[refresh_token] = user_id + return AuthTokens( + access_token=access_token, + refresh_token=refresh_token, + expires_at=access_exp * 1000, # milliseconds for client + ) + + +# ── Request bodies ──────────────────────────────────────────────────── + +class _RegisterRequest(BaseModel): + email: str + password: str + + +class _LoginRequest(BaseModel): + email: str + password: str + + +class _RefreshRequest(BaseModel): + refresh_token: str + + +# ── Routes ──────────────────────────────────────────────────────────── + +@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED) +async def register(body: _RegisterRequest) -> AuthTokens: + """Create a new account and return JWT tokens.""" + if body.email in _users: + raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") + user_id = str(uuid.uuid4()) + _users[body.email] = { + "id": user_id, + "email": body.email, + "password_hash": _hash_password(body.password), + "tier": "free", + } + return _make_tokens(user_id, body.email, "free") + + +@router.post("/login", response_model=AuthTokens) +async def login(body: _LoginRequest) -> AuthTokens: + """Validate credentials and return JWT tokens.""" + user = _users.get(body.email) + if not user or not _verify_password(body.password, user["password_hash"]): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials") + return _make_tokens(user["id"], user["email"], user["tier"]) + + +@router.post("/refresh", response_model=AuthTokens) +async def refresh(body: _RefreshRequest) -> AuthTokens: + """Rotate a refresh token and return a new token pair.""" + user_id = _refresh_tokens.pop(body.refresh_token, None) + if user_id is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token") + user = next((u for u in _users.values() if u["id"] == user_id), None) + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found") + return _make_tokens(user["id"], user["email"], user["tier"]) + + +@router.get("/me", response_model=UserProfile) +async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile: + """Return the profile for the authenticated user.""" + return current_user diff --git a/app/config/settings.py b/app/config/settings.py index 6a154f8..c9d7042 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -17,6 +17,11 @@ class Settings(BaseSettings): AWS_ACCESS_KEY_ID: str = "" AWS_SECRET_ACCESS_KEY: str = "" + PINECONE_API_KEY: str = "" + PINECONE_INDEX: str = "adiuva" + QDRANT_URL: str = "" + QDRANT_API_KEY: str = "" + OPENAI_API_KEY: str = "" CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] diff --git a/app/core/execution_plan.py b/app/core/execution_plan.py index a6edd3a..b763937 100644 --- a/app/core/execution_plan.py +++ b/app/core/execution_plan.py @@ -156,29 +156,33 @@ def _register_builtin_templates() -> None: _tpls: dict[str, str] = { "tpl_task_agent_default": ( "You are a task management assistant. Help the user create, update, " - "and prioritize tasks based on their message and context." + "list, and track tasks. Use correct status values (todo, in_progress, " + "done) and priority values (high, medium, low) from the workspace model." ), - "tpl_calendar_agent_default": ( - "You are a calendar assistant. Help manage events, detect scheduling " - "conflicts, and suggest improvements based on the provided context." + "tpl_checkpoint_agent_default": ( + "You are a project checkpoint assistant. Help the user create and manage " + "milestone checkpoints on their projects. Every checkpoint requires a " + "project_id and a date expressed as a Unix timestamp in milliseconds." ), - "tpl_email_agent_default": ( - "You are an email analysis assistant. Classify emails, extract action " - "items, and draft responses using only the metadata provided." + "tpl_project_agent_default": ( + "You are a project management assistant. Help the user create, find, " + "update, and archive projects. Projects have a name, an optional client, " + "and a status of either active or archived." ), - "tpl_analytics_agent_default": ( - "You are a workspace analytics assistant. Calculate metrics, generate " - "reports, and surface trends from the data provided in context." + "tpl_note_agent_default": ( + "You are a note-taking assistant. Help the user create, retrieve, update, " + "and delete Markdown notes. Notes can optionally be linked to a project." ), - "tpl_email_extract_action_items": ( - "Extract all action items from the provided email metadata. " - "Return a structured list of tasks, each with a title, inferred " - "priority, and suggested due date where possible." + "tpl_task_extract_from_project": ( + "Extract all actionable tasks from the provided project context. " + "Return a structured list of tasks, each with a title, inferred priority " + "(high, medium, or low), suggested status (todo), and a due_date in " + "milliseconds where a deadline can be inferred." ), - "tpl_analytics_weekly_summary": ( - "Generate a weekly performance summary from the provided analytics " - "data. Include task completion rate, overdue item count, top " - "priorities for the coming week, and notable trends." + "tpl_note_weekly_summary": ( + "Generate a weekly project summary note from the provided workspace data. " + "Include: tasks completed this week, tasks due soon, active projects, " + "and upcoming checkpoints. Format the output as clean Markdown." ), } for tid, text in _tpls.items(): @@ -189,20 +193,20 @@ def _load_playbooks() -> None: """Pre-build and cache the built-in playbooks.""" playbooks: list[tuple[str, ExecutionPlan]] = [ ( - "create_task_from_email", - ExecutionPlanBuilder("email_agent") + "create_tasks_from_project", + ExecutionPlanBuilder("project_agent") .add_llm_step( - "tpl_email_extract_action_items", - {"source": "email_metadata"}, + "tpl_task_extract_from_project", + {"source": "project_context"}, ) .add_data_step("create_record", data_from_step=0) .build(), ), ( - "generate_weekly_report", - ExecutionPlanBuilder("analytics_agent") + "generate_weekly_note", + ExecutionPlanBuilder("note_agent") .add_llm_step( - "tpl_analytics_weekly_summary", + "tpl_note_weekly_summary", {"period": "last_7_days"}, ) .add_data_step("create_record", data_from_step=0) diff --git a/app/schemas.py b/app/schemas.py index 0737824..ab291b8 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -82,3 +82,76 @@ class BackupMetadata(BaseModel): timestamp: int checksum: str chunk_count: int + + +# ── Cloud Storage (E2E encrypted blobs) ────────────────────────────── + +class StorageRecord(BaseModel): + id: str + user_id: str + table: str + blob: bytes + checksum: str + created_at: int + updated_at: int + + +class StorageRecordCreate(BaseModel): + table: str + blob: bytes + checksum: str + + +class StorageRecordUpdate(BaseModel): + blob: bytes + checksum: str + + +# ── Cloud Vector Store (E2E encrypted vectors) ──────────────────────── + +class VectorItem(BaseModel): + id: str + blob: bytes # encrypted vector + metadata — backend never decrypts + checksum: str + + +class VectorUpsertRequest(BaseModel): + vectors: list[VectorItem] + + +class VectorSearchRequest(BaseModel): + query_blob: bytes # encrypted query — backend never decrypts + top_k: int = 10 + + +class VectorSearchResult(BaseModel): + id: str + score: float + blob: bytes + + +class VectorSearchResponse(BaseModel): + results: list[VectorSearchResult] + + +# ── Plugin Marketplace ──────────────────────────────────────────────── + +class PluginManifest(BaseModel): + id: str + name: str + description: str + version: str + author: str + permissions: list[str] + category: str + price_cents: int = 0 + + +class PluginListResponse(BaseModel): + plugins: list[PluginManifest] + total: int + page: int + + +class PluginInstallRequest(BaseModel): + plugin_id: str diff --git a/app/storage/__init__.py b/app/storage/__init__.py new file mode 100644 index 0000000..9223ba7 --- /dev/null +++ b/app/storage/__init__.py @@ -0,0 +1 @@ +"""Cloud storage layer — E2E encrypted blobs and vectors.""" diff --git a/app/storage/blob_store.py b/app/storage/blob_store.py new file mode 100644 index 0000000..48ee190 --- /dev/null +++ b/app/storage/blob_store.py @@ -0,0 +1,105 @@ +"""S3-backed store for E2E-encrypted blobs. + +Keys are structured as ``{user_id}/{table}/{record_id}``. +The backend never inspects blob content — it stores and retrieves opaque bytes. +""" + +from __future__ import annotations + +from typing import Any + +import boto3 +from botocore.exceptions import ClientError + +from app.config.settings import settings + + +class BlobStore: + """Thin wrapper around boto3 S3. + + All blobs must be E2E encrypted by the client before upload. + The backend adds SSE-S3 as an extra layer of at-rest encryption + but cannot decrypt the inner client-side payload. + """ + + def _client(self) -> Any: + return boto3.client( + "s3", + region_name=settings.S3_REGION, + aws_access_key_id=settings.AWS_ACCESS_KEY_ID, + aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, + ) + + @staticmethod + def _key(user_id: str, table: str, record_id: str) -> str: + return f"{user_id}/{table}/{record_id}" + + async def upload( + self, + user_id: str, + table: str, + record_id: str, + blob: bytes, + checksum: str, + ) -> str: + """Store *blob* in S3 and return the S3 key. + + Args: + user_id: Owner of the blob (used as key prefix). + table: Logical table name (e.g. ``"tasks"``). + record_id: Record UUID. + blob: Raw bytes (pre-encrypted by client). + checksum: SHA-256 hex digest supplied by the client; stored as + object metadata for download-time verification. + + Returns: + The S3 key under which the blob was stored. + """ + key = self._key(user_id, table, record_id) + self._client().put_object( + Bucket=settings.S3_BUCKET, + Key=key, + Body=blob, + ServerSideEncryption="AES256", # SSE-S3 at rest + Metadata={"checksum": checksum}, + ) + return key + + async def download(self, user_id: str, s3_key: str) -> bytes: + """Retrieve the blob stored at *s3_key*. + + *user_id* is retained in the signature so higher-level code can + enforce ownership without re-parsing the key. + + Raises: + ``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the + object does not exist. + """ + response = self._client().get_object( + Bucket=settings.S3_BUCKET, + Key=s3_key, + ) + return response["Body"].read() + + async def delete(self, user_id: str, s3_key: str) -> None: + """Delete the object at *s3_key*. + + S3 ``delete_object`` is idempotent — it succeeds even if the key does + not exist. + """ + self._client().delete_object( + Bucket=settings.S3_BUCKET, + Key=s3_key, + ) + + async def list_keys(self, user_id: str, table: str) -> list[str]: + """Return all S3 keys for a given user + table combination. + + Uses the prefix ``{user_id}/{table}/`` to scope the listing. + """ + prefix = f"{user_id}/{table}/" + response = self._client().list_objects_v2( + Bucket=settings.S3_BUCKET, + Prefix=prefix, + ) + return [obj["Key"] for obj in response.get("Contents", [])] diff --git a/app/storage/encryption.py b/app/storage/encryption.py new file mode 100644 index 0000000..2dfefa2 --- /dev/null +++ b/app/storage/encryption.py @@ -0,0 +1,32 @@ +"""Integrity verification only — the backend NEVER decrypts user data.""" + +from __future__ import annotations + +import hashlib +import hmac + +from fastapi import HTTPException + + +def verify_checksum(blob: bytes, checksum: str) -> bool: + """Return ``True`` if SHA-256(blob) matches *checksum*. + + Uses ``hmac.compare_digest`` for constant-time comparison to prevent + timing-based side-channel attacks. + """ + computed = hashlib.sha256(blob).hexdigest() + return hmac.compare_digest(computed, checksum) + + +def reject_if_tampered(blob: bytes, checksum: str) -> None: + """Raise ``HTTP 400`` if the blob does not match its checksum. + + Call this before storing or forwarding any client-provided blob. + The backend never holds decryption keys — this check only verifies + that the opaque bytes arrived intact. + """ + if not verify_checksum(blob, checksum): + raise HTTPException( + status_code=400, + detail="Checksum mismatch: blob integrity check failed", + ) diff --git a/app/storage/vector_store.py b/app/storage/vector_store.py new file mode 100644 index 0000000..a2d5c32 --- /dev/null +++ b/app/storage/vector_store.py @@ -0,0 +1,205 @@ +"""Cloud vector store — wraps Pinecone (default) or Qdrant. + +Vectors are pre-encrypted blobs from the client. The backend stores them +alongside a deterministic 32-dim float representation derived from the blob's +SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this +is a known trade-off documented in the backend plan. + +Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by +``user_id`` payload field on a shared collection. +""" + +from __future__ import annotations + +import base64 +import hashlib +from typing import Any + +from pinecone import Pinecone +from qdrant_client import QdrantClient +from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct + +from app.config.settings import settings +from app.schemas import VectorItem, VectorSearchResult + +_QDRANT_COLLECTION = "adiuva_vectors" + + +def _blob_to_vector(blob: bytes) -> list[float]: + """Derive a 32-dim float vector from *blob* for storage purposes only. + + Uses SHA-256 to produce a deterministic 32-byte fingerprint, then + normalises each byte to the range [-1.0, 1.0]. This vector carries no + semantic meaning on encrypted data. + """ + return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()] + + +class VectorStore: + """Thin wrapper around Pinecone or Qdrant. + + The backend to use is selected at runtime: + - Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty. + - Qdrant: otherwise (requires ``settings.QDRANT_URL``). + """ + + def _use_pinecone(self) -> bool: + return bool(settings.PINECONE_API_KEY) + + # ── Pinecone helpers ────────────────────────────────────────────── + + def _pinecone_index(self) -> Any: + pc = Pinecone(api_key=settings.PINECONE_API_KEY) + return pc.Index(settings.PINECONE_INDEX) + + # ── Qdrant helpers ──────────────────────────────────────────────── + + def _qdrant_client(self) -> Any: + return QdrantClient( + url=settings.QDRANT_URL, + api_key=settings.QDRANT_API_KEY or None, + ) + + # ── Public API ──────────────────────────────────────────────────── + + async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None: + """Store encrypted vectors in the backend. + + Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload + so it can be returned verbatim during search. + + Args: + user_id: Used as Pinecone namespace or Qdrant payload field. + vectors: List of encrypted vector items from the client. + """ + if self._use_pinecone(): + await self._pinecone_upsert(user_id, vectors) + else: + await self._qdrant_upsert(user_id, vectors) + + async def search( + self, + user_id: str, + query_blob: bytes, + top_k: int, + ) -> list[VectorSearchResult]: + """Query the vector store and return encrypted result blobs. + + The query vector is derived from *query_blob* using the same + deterministic mapping as upsert. + + Args: + user_id: Scopes the search to this user's namespace. + query_blob: Encrypted query from the client. + top_k: Maximum number of results to return. + + Returns: + List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``. + """ + if self._use_pinecone(): + return await self._pinecone_search(user_id, query_blob, top_k) + return await self._qdrant_search(user_id, query_blob, top_k) + + async def delete(self, user_id: str, vector_ids: list[str]) -> None: + """Remove vectors by ID, scoped to *user_id*. + + Args: + user_id: Namespace / payload filter to prevent cross-user deletion. + vector_ids: List of vector IDs to remove. + """ + if self._use_pinecone(): + await self._pinecone_delete(user_id, vector_ids) + else: + await self._qdrant_delete(user_id, vector_ids) + + # ── Pinecone implementation ─────────────────────────────────────── + + async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None: + index = self._pinecone_index() + records = [ + { + "id": v.id, + "values": _blob_to_vector(v.blob), + "metadata": { + "blob": base64.b64encode(v.blob).decode(), + "checksum": v.checksum, + "user_id": user_id, + }, + } + for v in vectors + ] + index.upsert(vectors=records, namespace=user_id) + + async def _pinecone_search( + self, user_id: str, query_blob: bytes, top_k: int + ) -> list[VectorSearchResult]: + index = self._pinecone_index() + query_vector = _blob_to_vector(query_blob) + response = index.query( + vector=query_vector, + top_k=top_k, + namespace=user_id, + include_metadata=True, + ) + results: list[VectorSearchResult] = [] + for match in response.get("matches", []): + blob_bytes = base64.b64decode(match["metadata"]["blob"]) + results.append( + VectorSearchResult( + id=match["id"], + score=match["score"], + blob=blob_bytes, + ) + ) + return results + + async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None: + index = self._pinecone_index() + index.delete(ids=vector_ids, namespace=user_id) + + # ── Qdrant implementation ───────────────────────────────────────── + + async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None: + client = self._qdrant_client() + points = [ + PointStruct( + id=v.id, + vector=_blob_to_vector(v.blob), + payload={ + "blob": base64.b64encode(v.blob).decode(), + "checksum": v.checksum, + "user_id": user_id, + }, + ) + for v in vectors + ] + client.upsert(collection_name=_QDRANT_COLLECTION, points=points) + + async def _qdrant_search( + self, user_id: str, query_blob: bytes, top_k: int + ) -> list[VectorSearchResult]: + client = self._qdrant_client() + query_vector = _blob_to_vector(query_blob) + hits = client.search( + collection_name=_QDRANT_COLLECTION, + query_vector=query_vector, + query_filter=Filter( + must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))] + ), + limit=top_k, + ) + return [ + VectorSearchResult( + id=str(hit.id), + score=hit.score, + blob=base64.b64decode(hit.payload["blob"]), + ) + for hit in hits + ] + + async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None: + client = self._qdrant_client() + client.delete( + collection_name=_QDRANT_COLLECTION, + points_selector=PointIdsList(points=vector_ids), + ) diff --git a/requirements.txt b/requirements.txt index a7590c1..f2465ff 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,6 @@ httpx>=0.28.0 websockets>=14.0 pytest>=8.0.0 pytest-asyncio>=0.24.0 +moto[s3]>=5.0.0 +pinecone>=5.0.0 +qdrant-client>=1.7.0 diff --git a/tests/test_agents.py b/tests/test_agents.py index ac8bba2..ebbcf86 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -1,4 +1,4 @@ -"""Unit tests for all four chat agents with mocked LLM.""" +"""Unit tests for the four domain-specific chat agents with mocked LLM.""" from __future__ import annotations @@ -9,9 +9,9 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest import app.agents # noqa: F401 — triggers @registry.register decorators -from app.agents.analytics_agent import AnalyticsAgent -from app.agents.calendar_agent import CalendarAgent -from app.agents.email_agent import EmailAgent +from app.agents.checkpoint_agent import CheckpointAgent +from app.agents.note_agent import NoteAgent +from app.agents.project_agent import ProjectAgent from app.agents.task_agent import TaskAgent from app.core.agent_registry import registry @@ -59,15 +59,15 @@ def _mock_llm_with_tool_call( class TestAgentRegistration: def test_all_agents_registered(self) -> None: names = {a["name"] for a in registry.list_agents()} - assert {"task_agent", "calendar_agent", "email_agent", "analytics_agent"}.issubset( - names - ) + assert { + "task_agent", "checkpoint_agent", "project_agent", "note_agent" + }.issubset(names) def test_registry_returns_correct_types(self) -> None: assert isinstance(registry.get("task_agent"), TaskAgent) - assert isinstance(registry.get("calendar_agent"), CalendarAgent) - assert isinstance(registry.get("email_agent"), EmailAgent) - assert isinstance(registry.get("analytics_agent"), AnalyticsAgent) + assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent) + assert isinstance(registry.get("project_agent"), ProjectAgent) + assert isinstance(registry.get("note_agent"), NoteAgent) def test_descriptions_present(self) -> None: for agent_info in registry.list_agents(): @@ -82,14 +82,23 @@ class TestTaskAgent: assert TaskAgent().get_name() == "task_agent" def test_description(self) -> None: - assert TaskAgent().get_description() == "Manages tasks: create, update, list, suggest" + assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments" def test_get_tools_count(self) -> None: - assert len(TaskAgent().get_tools()) == 4 + assert len(TaskAgent().get_tools()) == 8 def test_tool_names(self) -> None: names = {t.name for t in TaskAgent().get_tools()} - assert names == {"create_task", "update_task", "list_tasks", "suggest_tasks"} + assert names == { + "list_tasks", + "create_task", + "update_task", + "delete_task", + "list_tasks_due_today", + "list_task_comments", + "add_task_comment", + "delete_task_comment", + } @pytest.mark.asyncio async def test_handle_returns_string(self) -> None: @@ -111,10 +120,10 @@ class TestTaskAgent: mock_cls.return_value = _mock_llm_with_tool_call( "create_task", {"title": "Buy groceries", "priority": "low"}, - "Task 'Buy groceries' created with low priority.", + "Task 'Buy groceries' created.", ) result = await TaskAgent().handle("add a grocery task", {}) - assert result == "Task 'Buy groceries' created with low priority." + assert result == "Task 'Buy groceries' created." @pytest.mark.asyncio async def test_handle_accepts_empty_context(self) -> None: @@ -123,20 +132,11 @@ class TestTaskAgent: result = await TaskAgent().handle("help", {}) assert isinstance(result, str) - @pytest.mark.asyncio - async def test_handle_accepts_partial_context(self) -> None: - with patch("app.agents.task_agent.ChatOpenAI") as mock_cls: - mock_cls.return_value = _mock_llm("Done.") - result = await TaskAgent().handle("list tasks", {"user_profile": {"id": "u1"}}) - assert isinstance(result, str) - @pytest.mark.asyncio async def test_handle_accepts_rich_context(self) -> None: context = { "user_profile": {"id": "u1", "tier": "pro"}, "recent_tasks": [{"id": "t1", "title": "Old task"}], - "relevant_documents": ["doc1"], - "extra_plugin_data": {"batch_id": "b1"}, } with patch("app.agents.task_agent.ChatOpenAI") as mock_cls: mock_cls.return_value = _mock_llm("Tasks listed.") @@ -146,244 +146,475 @@ class TestTaskAgent: class TestTaskAgentTools: @pytest.mark.asyncio - async def test_create_task_returns_valid_json(self) -> None: + async def test_list_tasks_defaults(self) -> None: + from app.agents.task_agent import list_tasks + result = await list_tasks.ainvoke({}) + data = json.loads(result) + assert data["action"] == "list" + assert data["table"] == "tasks" + + @pytest.mark.asyncio + async def test_list_tasks_with_status_filter(self) -> None: + from app.agents.task_agent import list_tasks + result = await list_tasks.ainvoke({"status": "done"}) + data = json.loads(result) + assert data["filters"]["status"] == "done" + + @pytest.mark.asyncio + async def test_create_task_defaults(self) -> None: from app.agents.task_agent import create_task - result = await create_task.ainvoke({"title": "Test task", "priority": "high"}) + result = await create_task.ainvoke({"title": "Test task"}) data = json.loads(result) assert data["action"] == "create_record" assert data["table"] == "tasks" assert data["data"]["title"] == "Test task" - assert data["data"]["priority"] == "high" + assert data["data"]["status"] == "todo" + assert data["data"]["priority"] == "medium" @pytest.mark.asyncio - async def test_update_task_returns_valid_json(self) -> None: + async def test_create_task_with_all_fields(self) -> None: + from app.agents.task_agent import create_task + result = await create_task.ainvoke({ + "title": "Deploy", + "priority": "high", + "status": "in_progress", + "project_id": "p1", + "is_ai_suggested": 1, + }) + data = json.loads(result) + assert data["data"]["priority"] == "high" + assert data["data"]["status"] == "in_progress" + assert data["data"]["projectId"] == "p1" + assert data["data"]["isAiSuggested"] == 1 + + @pytest.mark.asyncio + async def test_update_task_with_status(self) -> None: from app.agents.task_agent import update_task - result = await update_task.ainvoke( - {"task_id": "t1", "updates": '{"priority": "urgent"}'} - ) + result = await update_task.ainvoke({"task_id": "t1", "status": "done"}) data = json.loads(result) assert data["action"] == "update_record" assert data["data"]["id"] == "t1" + assert data["data"]["updates"]["status"] == "done" @pytest.mark.asyncio - async def test_list_tasks_returns_valid_json(self) -> None: - from app.agents.task_agent import list_tasks - result = await list_tasks.ainvoke({"status": "open"}) + async def test_update_task_empty_updates(self) -> None: + from app.agents.task_agent import update_task + result = await update_task.ainvoke({"task_id": "t1"}) data = json.loads(result) - assert data["action"] == "list" + assert data["data"]["updates"] == {} + + @pytest.mark.asyncio + async def test_delete_task(self) -> None: + from app.agents.task_agent import delete_task + result = await delete_task.ainvoke({"task_id": "t1"}) + data = json.loads(result) + assert data["action"] == "delete_record" + assert data["table"] == "tasks" + assert data["data"]["id"] == "t1" + + @pytest.mark.asyncio + async def test_list_tasks_due_today(self) -> None: + from app.agents.task_agent import list_tasks_due_today + result = await list_tasks_due_today.ainvoke({}) + data = json.loads(result) + assert data["action"] == "list_due_today" assert data["table"] == "tasks" @pytest.mark.asyncio - async def test_suggest_tasks_returns_valid_json(self) -> None: - from app.agents.task_agent import suggest_tasks - result = await suggest_tasks.ainvoke({"context": "lots of meetings this week"}) - data = json.loads(result) - assert data["action"] == "suggest" - - -# ── CalendarAgent ───────────────────────────────────────────────────── - - -class TestCalendarAgent: - def test_name(self) -> None: - assert CalendarAgent().get_name() == "calendar_agent" - - def test_description(self) -> None: - assert CalendarAgent().get_description() == "Calendar management: events, conflicts, scheduling" - - def test_get_tools_count(self) -> None: - assert len(CalendarAgent().get_tools()) == 3 - - def test_tool_names(self) -> None: - names = {t.name for t in CalendarAgent().get_tools()} - assert names == {"list_events", "detect_conflicts", "suggest_reschedule"} - - @pytest.mark.asyncio - async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls: - mock_cls.return_value = _mock_llm("No conflicts found.") - result = await CalendarAgent().handle("check my schedule", {}) - assert result == "No conflicts found." - - @pytest.mark.asyncio - async def test_handle_with_list_events_tool_call(self) -> None: - with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls: - mock_cls.return_value = _mock_llm_with_tool_call( - "list_events", - {"date_range": "2024-01-01/2024-01-07"}, - "You have 3 events next week.", - ) - result = await CalendarAgent().handle("what events do I have?", {}) - assert result == "You have 3 events next week." - - @pytest.mark.asyncio - async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls: - mock_cls.return_value = _mock_llm("Done.") - result = await CalendarAgent().handle("reschedule meeting", {}) - assert isinstance(result, str) - - -class TestCalendarAgentTools: - @pytest.mark.asyncio - async def test_list_events_returns_valid_json(self) -> None: - from app.agents.calendar_agent import list_events - result = await list_events.ainvoke({"date_range": "2024-01-01/2024-01-07"}) + async def test_list_task_comments(self) -> None: + from app.agents.task_agent import list_task_comments + result = await list_task_comments.ainvoke({"task_id": "t1"}) data = json.loads(result) assert data["action"] == "list" - assert data["table"] == "events" - assert data["filters"]["date_range"] == "2024-01-01/2024-01-07" + assert data["table"] == "taskComments" + assert data["filters"]["taskId"] == "t1" @pytest.mark.asyncio - async def test_detect_conflicts_returns_valid_json(self) -> None: - from app.agents.calendar_agent import detect_conflicts - result = await detect_conflicts.ainvoke({"events": "[]"}) + async def test_add_task_comment(self) -> None: + from app.agents.task_agent import add_task_comment + result = await add_task_comment.ainvoke({ + "task_id": "t1", + "author": "Alice", + "content": "Looks good!", + }) data = json.loads(result) - assert data["action"] == "analyse" + assert data["action"] == "create_record" + assert data["table"] == "taskComments" + assert data["data"]["taskId"] == "t1" + assert data["data"]["author"] == "Alice" + assert data["data"]["content"] == "Looks good!" @pytest.mark.asyncio - async def test_suggest_reschedule_returns_valid_json(self) -> None: - from app.agents.calendar_agent import suggest_reschedule - result = await suggest_reschedule.ainvoke({"conflict": '{"event": "standup"}'}) + async def test_delete_task_comment(self) -> None: + from app.agents.task_agent import delete_task_comment + result = await delete_task_comment.ainvoke({"comment_id": "c1"}) data = json.loads(result) - assert data["action"] == "suggest_reschedule" + assert data["action"] == "delete_record" + assert data["table"] == "taskComments" + assert data["data"]["id"] == "c1" -# ── EmailAgent ──────────────────────────────────────────────────────── +# ── CheckpointAgent ─────────────────────────────────────────────────── -class TestEmailAgent: +class TestCheckpointAgent: def test_name(self) -> None: - assert EmailAgent().get_name() == "email_agent" + assert CheckpointAgent().get_name() == "checkpoint_agent" def test_description(self) -> None: - assert EmailAgent().get_description() == "Email analysis: classify, extract actions, draft responses" + assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete" def test_get_tools_count(self) -> None: - assert len(EmailAgent().get_tools()) == 3 + assert len(CheckpointAgent().get_tools()) == 4 def test_tool_names(self) -> None: - names = {t.name for t in EmailAgent().get_tools()} - assert names == {"classify_email", "extract_action_items", "draft_response"} + names = {t.name for t in CheckpointAgent().get_tools()} + assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"} @pytest.mark.asyncio async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.email_agent.ChatOpenAI") as mock_cls: - mock_cls.return_value = _mock_llm("Email classified as action_required.") - result = await EmailAgent().handle("classify this email", {}) - assert result == "Email classified as action_required." + with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls: + mock_cls.return_value = _mock_llm("No checkpoints found.") + result = await CheckpointAgent().handle("list checkpoints", {}) + assert result == "No checkpoints found." @pytest.mark.asyncio - async def test_handle_with_classify_tool_call(self) -> None: - with patch("app.agents.email_agent.ChatOpenAI") as mock_cls: + async def test_handle_with_create_tool_call(self) -> None: + with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls: mock_cls.return_value = _mock_llm_with_tool_call( - "classify_email", - {"metadata": '{"subject": "URGENT: action needed"}'}, - "This email requires immediate action.", + "create_checkpoint", + {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000}, + "Checkpoint 'MVP Launch' created.", ) - result = await EmailAgent().handle("what is this email about?", {}) - assert result == "This email requires immediate action." + result = await CheckpointAgent().handle("add MVP checkpoint", {}) + assert result == "Checkpoint 'MVP Launch' created." @pytest.mark.asyncio async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.email_agent.ChatOpenAI") as mock_cls: + with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls: mock_cls.return_value = _mock_llm("Done.") - result = await EmailAgent().handle("draft a reply", {}) + result = await CheckpointAgent().handle("show milestones", {}) assert isinstance(result, str) -class TestEmailAgentTools: +class TestCheckpointAgentTools: @pytest.mark.asyncio - async def test_classify_email_returns_valid_json(self) -> None: - from app.agents.email_agent import classify_email - result = await classify_email.ainvoke({"metadata": '{"subject": "Meeting"}' }) + async def test_list_checkpoints_no_project(self) -> None: + from app.agents.checkpoint_agent import list_checkpoints + result = await list_checkpoints.ainvoke({}) data = json.loads(result) - assert data["action"] == "classify" - assert "result" in data - assert "category" in data["result"] + assert data["action"] == "list" + assert data["table"] == "checkpoints" + assert data["filters"]["projectId"] is None @pytest.mark.asyncio - async def test_extract_action_items_returns_valid_json(self) -> None: - from app.agents.email_agent import extract_action_items - result = await extract_action_items.ainvoke({"metadata": '{"subject": "Follow up"}'}) + async def test_list_checkpoints_with_project(self) -> None: + from app.agents.checkpoint_agent import list_checkpoints + result = await list_checkpoints.ainvoke({"project_id": "p1"}) data = json.loads(result) - assert data["action"] == "extract" - assert "action_items" in data["result"] + assert data["filters"]["projectId"] == "p1" @pytest.mark.asyncio - async def test_draft_response_returns_valid_json(self) -> None: - from app.agents.email_agent import draft_response - result = await draft_response.ainvoke({"thread_context": '{"thread_id": "t1"}'}) + async def test_create_checkpoint(self) -> None: + from app.agents.checkpoint_agent import create_checkpoint + result = await create_checkpoint.ainvoke({ + "project_id": "p1", + "title": "Beta release", + "date": 1700000000000, + }) data = json.loads(result) - assert data["action"] == "draft" + assert data["action"] == "create_record" + assert data["table"] == "checkpoints" + assert data["data"]["projectId"] == "p1" + assert data["data"]["title"] == "Beta release" + assert data["data"]["date"] == 1700000000000 + + @pytest.mark.asyncio + async def test_create_checkpoint_ai_suggested(self) -> None: + from app.agents.checkpoint_agent import create_checkpoint + result = await create_checkpoint.ainvoke({ + "project_id": "p1", + "title": "Review", + "date": 1700000000000, + "is_ai_suggested": 1, + }) + data = json.loads(result) + assert data["data"]["isAiSuggested"] == 1 + assert data["data"]["isApproved"] == 0 + + @pytest.mark.asyncio + async def test_update_checkpoint_approve(self) -> None: + from app.agents.checkpoint_agent import update_checkpoint + result = await update_checkpoint.ainvoke({ + "checkpoint_id": "c1", + "is_approved": 1, + }) + data = json.loads(result) + assert data["action"] == "update_record" + assert data["data"]["id"] == "c1" + assert data["data"]["updates"]["isApproved"] == 1 + + @pytest.mark.asyncio + async def test_update_checkpoint_empty_updates(self) -> None: + from app.agents.checkpoint_agent import update_checkpoint + result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"}) + data = json.loads(result) + assert data["data"]["updates"] == {} + + @pytest.mark.asyncio + async def test_delete_checkpoint(self) -> None: + from app.agents.checkpoint_agent import delete_checkpoint + result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"}) + data = json.loads(result) + assert data["action"] == "delete_record" + assert data["table"] == "checkpoints" + assert data["data"]["id"] == "c1" -# ── AnalyticsAgent ──────────────────────────────────────────────────── +# ── ProjectAgent ────────────────────────────────────────────────────── -class TestAnalyticsAgent: +class TestProjectAgent: def test_name(self) -> None: - assert AnalyticsAgent().get_name() == "analytics_agent" + assert ProjectAgent().get_name() == "project_agent" def test_description(self) -> None: - assert AnalyticsAgent().get_description() == "Workspace analytics: metrics, reports, trends" + assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete" def test_get_tools_count(self) -> None: - assert len(AnalyticsAgent().get_tools()) == 3 + assert len(ProjectAgent().get_tools()) == 6 def test_tool_names(self) -> None: - names = {t.name for t in AnalyticsAgent().get_tools()} - assert names == {"calculate_metrics", "generate_report", "trend_analysis"} + names = {t.name for t in ProjectAgent().get_tools()} + assert names == { + "list_projects", + "list_all_projects", + "get_project", + "create_project", + "update_project", + "delete_project", + } @pytest.mark.asyncio async def test_handle_no_tool_calls(self) -> None: - with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls: - mock_cls.return_value = _mock_llm("Completion rate is 78%.") - result = await AnalyticsAgent().handle("show my metrics", {}) - assert result == "Completion rate is 78%." + with patch("app.agents.project_agent.ChatOpenAI") as mock_cls: + mock_cls.return_value = _mock_llm("Project Alpha is active.") + result = await ProjectAgent().handle("show my projects", {}) + assert result == "Project Alpha is active." @pytest.mark.asyncio - async def test_handle_with_generate_report_tool_call(self) -> None: - with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls: + async def test_handle_with_create_project_tool_call(self) -> None: + with patch("app.agents.project_agent.ChatOpenAI") as mock_cls: mock_cls.return_value = _mock_llm_with_tool_call( - "generate_report", - {"period": "last_7_days", "data": "[]"}, - "Weekly report: 12 tasks completed, 2 overdue.", + "create_project", + {"name": "Pippo"}, + "Project 'Pippo' created.", ) - result = await AnalyticsAgent().handle("weekly report", {}) - assert result == "Weekly report: 12 tasks completed, 2 overdue." + result = await ProjectAgent().handle("create project Pippo", {}) + assert result == "Project 'Pippo' created." @pytest.mark.asyncio async def test_handle_accepts_empty_context(self) -> None: - with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls: + with patch("app.agents.project_agent.ChatOpenAI") as mock_cls: mock_cls.return_value = _mock_llm("Done.") - result = await AnalyticsAgent().handle("analyse trends", {}) + result = await ProjectAgent().handle("archive old project", {}) assert isinstance(result, str) -class TestAnalyticsAgentTools: +class TestProjectAgentTools: @pytest.mark.asyncio - async def test_calculate_metrics_returns_valid_json(self) -> None: - from app.agents.analytics_agent import calculate_metrics - result = await calculate_metrics.ainvoke({"task_data": "[]"}) + async def test_list_projects_defaults(self) -> None: + from app.agents.project_agent import list_projects + result = await list_projects.ainvoke({}) data = json.loads(result) - assert data["action"] == "calculate" - assert "result" in data - assert "completion_rate" in data["result"] + assert data["action"] == "list" + assert data["table"] == "projects" + assert data["filters"]["includeArchived"] is False @pytest.mark.asyncio - async def test_generate_report_returns_valid_json(self) -> None: - from app.agents.analytics_agent import generate_report - result = await generate_report.ainvoke({"period": "last_7_days", "data": "[]"}) + async def test_list_projects_include_archived(self) -> None: + from app.agents.project_agent import list_projects + result = await list_projects.ainvoke({"include_archived": 1}) data = json.loads(result) - assert data["action"] == "report" - assert data["period"] == "last_7_days" + assert data["filters"]["includeArchived"] is True @pytest.mark.asyncio - async def test_trend_analysis_returns_valid_json(self) -> None: - from app.agents.analytics_agent import trend_analysis - result = await trend_analysis.ainvoke({"data_points": "[]"}) + async def test_list_all_projects(self) -> None: + from app.agents.project_agent import list_all_projects + result = await list_all_projects.ainvoke({}) data = json.loads(result) - assert data["action"] == "trend" - assert "result" in data - assert "anomalies" in data["result"] + assert data["action"] == "list_all" + assert data["table"] == "projects" + + @pytest.mark.asyncio + async def test_get_project(self) -> None: + from app.agents.project_agent import get_project + result = await get_project.ainvoke({"project_id": "p1"}) + data = json.loads(result) + assert data["action"] == "get" + assert data["table"] == "projects" + assert data["data"]["id"] == "p1" + + @pytest.mark.asyncio + async def test_create_project_name_only(self) -> None: + from app.agents.project_agent import create_project + result = await create_project.ainvoke({"name": "Alpha"}) + data = json.loads(result) + assert data["action"] == "create_record" + assert data["data"]["name"] == "Alpha" + assert data["data"]["clientId"] is None + + @pytest.mark.asyncio + async def test_create_project_with_client(self) -> None: + from app.agents.project_agent import create_project + result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"}) + data = json.loads(result) + assert data["data"]["clientId"] == "cl1" + + @pytest.mark.asyncio + async def test_update_project_archive(self) -> None: + from app.agents.project_agent import update_project + result = await update_project.ainvoke({"project_id": "p1", "status": "archived"}) + data = json.loads(result) + assert data["action"] == "update_record" + assert data["data"]["id"] == "p1" + assert data["data"]["updates"]["status"] == "archived" + + @pytest.mark.asyncio + async def test_update_project_empty_updates(self) -> None: + from app.agents.project_agent import update_project + result = await update_project.ainvoke({"project_id": "p1"}) + data = json.loads(result) + assert data["data"]["updates"] == {} + + @pytest.mark.asyncio + async def test_delete_project(self) -> None: + from app.agents.project_agent import delete_project + result = await delete_project.ainvoke({"project_id": "p1"}) + data = json.loads(result) + assert data["action"] == "delete_record" + assert data["data"]["id"] == "p1" + + +# ── NoteAgent ───────────────────────────────────────────────────────── + + +class TestNoteAgent: + def test_name(self) -> None: + assert NoteAgent().get_name() == "note_agent" + + def test_description(self) -> None: + assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete" + + def test_get_tools_count(self) -> None: + assert len(NoteAgent().get_tools()) == 5 + + def test_tool_names(self) -> None: + names = {t.name for t in NoteAgent().get_tools()} + assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"} + + @pytest.mark.asyncio + async def test_handle_no_tool_calls(self) -> None: + with patch("app.agents.note_agent.ChatOpenAI") as mock_cls: + mock_cls.return_value = _mock_llm("Note created.") + result = await NoteAgent().handle("create a note", {}) + assert result == "Note created." + + @pytest.mark.asyncio + async def test_handle_with_create_note_tool_call(self) -> None: + with patch("app.agents.note_agent.ChatOpenAI") as mock_cls: + mock_cls.return_value = _mock_llm_with_tool_call( + "create_note", + {"title": "Daily log", "content": "# Today\nAll good."}, + "Note 'Daily log' created.", + ) + result = await NoteAgent().handle("log today's progress", {}) + assert result == "Note 'Daily log' created." + + @pytest.mark.asyncio + async def test_handle_accepts_empty_context(self) -> None: + with patch("app.agents.note_agent.ChatOpenAI") as mock_cls: + mock_cls.return_value = _mock_llm("Done.") + result = await NoteAgent().handle("show notes", {}) + assert isinstance(result, str) + + +class TestNoteAgentTools: + @pytest.mark.asyncio + async def test_list_notes_no_project(self) -> None: + from app.agents.note_agent import list_notes + result = await list_notes.ainvoke({}) + data = json.loads(result) + assert data["action"] == "list" + assert data["table"] == "notes" + assert data["filters"]["projectId"] is None + + @pytest.mark.asyncio + async def test_list_notes_with_project(self) -> None: + from app.agents.note_agent import list_notes + result = await list_notes.ainvoke({"project_id": "p1"}) + data = json.loads(result) + assert data["filters"]["projectId"] == "p1" + + @pytest.mark.asyncio + async def test_get_note(self) -> None: + from app.agents.note_agent import get_note + result = await get_note.ainvoke({"note_id": "n1"}) + data = json.loads(result) + assert data["action"] == "get" + assert data["table"] == "notes" + assert data["data"]["id"] == "n1" + + @pytest.mark.asyncio + async def test_create_note_minimal(self) -> None: + from app.agents.note_agent import create_note + result = await create_note.ainvoke({ + "title": "Daily log", + "content": "# Today\nAll good.", + }) + data = json.loads(result) + assert data["action"] == "create_record" + assert data["table"] == "notes" + assert data["data"]["title"] == "Daily log" + assert data["data"]["content"] == "# Today\nAll good." + assert data["data"]["projectId"] is None + + @pytest.mark.asyncio + async def test_create_note_with_project(self) -> None: + from app.agents.note_agent import create_note + result = await create_note.ainvoke({ + "title": "Sprint notes", + "content": "## Sprint 1", + "project_id": "p1", + }) + data = json.loads(result) + assert data["data"]["projectId"] == "p1" + + @pytest.mark.asyncio + async def test_update_note_content_only(self) -> None: + from app.agents.note_agent import update_note + result = await update_note.ainvoke({ + "note_id": "n1", + "content": "# Updated content", + }) + data = json.loads(result) + assert data["action"] == "update_record" + assert data["data"]["id"] == "n1" + assert data["data"]["updates"]["content"] == "# Updated content" + assert "title" not in data["data"]["updates"] + + @pytest.mark.asyncio + async def test_update_note_empty_updates(self) -> None: + from app.agents.note_agent import update_note + result = await update_note.ainvoke({"note_id": "n1"}) + data = json.loads(result) + assert data["data"]["updates"] == {} + + @pytest.mark.asyncio + async def test_delete_note(self) -> None: + from app.agents.note_agent import delete_note + result = await delete_note.ainvoke({"note_id": "n1"}) + data = json.loads(result) + assert data["action"] == "delete_record" + assert data["table"] == "notes" + assert data["data"]["id"] == "n1" diff --git a/tests/test_execution_plan.py b/tests/test_execution_plan.py index 03e2db7..f468177 100644 --- a/tests/test_execution_plan.py +++ b/tests/test_execution_plan.py @@ -243,14 +243,14 @@ class TestPlanCache: class TestModuleSingletons: def test_template_registry_has_all_agent_defaults(self) -> None: - for agent in ("task_agent", "calendar_agent", "email_agent", "analytics_agent"): + for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"): assert template_registry.has(f"tpl_{agent}_default"), ( f"Missing template: tpl_{agent}_default" ) def test_template_registry_has_operation_templates(self) -> None: - assert template_registry.has("tpl_email_extract_action_items") - assert template_registry.has("tpl_analytics_weekly_summary") + assert template_registry.has("tpl_task_extract_from_project") + assert template_registry.has("tpl_note_weekly_summary") def test_template_registry_get_returns_non_empty_string(self) -> None: text = template_registry.get("tpl_task_agent_default") @@ -260,20 +260,20 @@ class TestModuleSingletons: def test_plan_cache_has_prebuilt_playbooks(self) -> None: assert len(plan_cache.get_all_playbooks()) >= 2 - def test_playbook_create_task_from_email(self) -> None: - plan = plan_cache.get_plan("create_task_from_email") + def test_playbook_create_tasks_from_project(self) -> None: + plan = plan_cache.get_plan("create_tasks_from_project") assert plan is not None - assert plan.agent == "email_agent" + assert plan.agent == "project_agent" assert len(plan.steps) == 2 - assert plan.steps[0].prompt_template == "tpl_email_extract_action_items" + assert plan.steps[0].prompt_template == "tpl_task_extract_from_project" assert plan.steps[1].data_from_step == 0 - def test_playbook_generate_weekly_report(self) -> None: - plan = plan_cache.get_plan("generate_weekly_report") + def test_playbook_generate_weekly_note(self) -> None: + plan = plan_cache.get_plan("generate_weekly_note") assert plan is not None - assert plan.agent == "analytics_agent" + assert plan.agent == "note_agent" assert len(plan.steps) == 2 - assert plan.steps[0].prompt_template == "tpl_analytics_weekly_summary" + assert plan.steps[0].prompt_template == "tpl_note_weekly_summary" assert plan.steps[1].data_from_step == 0 def test_playbook_steps_have_no_raw_prompt_text(self) -> None: diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..3e6a7dc --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,385 @@ +"""Tests for the storage layer: encryption, BlobStore, and VectorStore.""" + +from __future__ import annotations + +import base64 +import hashlib +import os +from unittest.mock import MagicMock, patch + +import boto3 +import pytest +from botocore.exceptions import ClientError +from moto import mock_aws + +from app.storage.encryption import reject_if_tampered, verify_checksum +from app.storage.blob_store import BlobStore +from app.storage.vector_store import VectorStore, _blob_to_vector +from app.schemas import VectorItem, VectorSearchResult + + +# ── Helpers ─────────────────────────────────────────────────────────── + +_BLOB = b"encrypted-payload-opaque-to-server" +_CHECKSUM = hashlib.sha256(_BLOB).hexdigest() +_BUCKET = "test-bucket" +_REGION = "us-east-1" + + +@pytest.fixture +def s3_bucket(): + """Create a mocked S3 bucket and expose its name.""" + with mock_aws(): + os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing") + os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing") + os.environ.setdefault("AWS_DEFAULT_REGION", _REGION) + client = boto3.client("s3", region_name=_REGION) + client.create_bucket(Bucket=_BUCKET) + with patch("app.storage.blob_store.settings") as mock_settings: + mock_settings.S3_BUCKET = _BUCKET + mock_settings.S3_REGION = _REGION + mock_settings.AWS_ACCESS_KEY_ID = "testing" + mock_settings.AWS_SECRET_ACCESS_KEY = "testing" + yield _BUCKET + + +def _pinecone_mock(): + """Return a mock Pinecone index with realistic return shapes.""" + mock_index = MagicMock() + mock_index.query.return_value = { + "matches": [ + { + "id": "v1", + "score": 0.95, + "metadata": { + "blob": base64.b64encode(b"result-blob").decode(), + "checksum": hashlib.sha256(b"result-blob").hexdigest(), + "user_id": "u1", + }, + } + ] + } + mock_pc = MagicMock() + mock_pc.return_value.Index.return_value = mock_index + return mock_pc, mock_index + + +# ── TestEncryption ──────────────────────────────────────────────────── + + +class TestEncryption: + def test_verify_checksum_correct(self) -> None: + assert verify_checksum(_BLOB, _CHECKSUM) is True + + def test_verify_checksum_wrong(self) -> None: + assert verify_checksum(_BLOB, "0" * 64) is False + + def test_verify_checksum_empty_checksum(self) -> None: + assert verify_checksum(_BLOB, "") is False + + def test_verify_checksum_empty_blob(self) -> None: + expected = hashlib.sha256(b"").hexdigest() + assert verify_checksum(b"", expected) is True + + def test_verify_checksum_tampered_blob(self) -> None: + tampered = _BLOB + b"\x00" + assert verify_checksum(tampered, _CHECKSUM) is False + + def test_reject_if_tampered_passes_when_valid(self) -> None: + # Should not raise + reject_if_tampered(_BLOB, _CHECKSUM) + + def test_reject_if_tampered_raises_400_on_mismatch(self) -> None: + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + reject_if_tampered(_BLOB, "bad" * 20) + assert exc_info.value.status_code == 400 + + def test_reject_if_tampered_detail_mentions_checksum(self) -> None: + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + reject_if_tampered(_BLOB, "bad" * 20) + assert "checksum" in exc_info.value.detail.lower() + + def test_checksum_is_sha256_hex(self) -> None: + cs = hashlib.sha256(_BLOB).hexdigest() + assert len(cs) == 64 + assert all(c in "0123456789abcdef" for c in cs) + + +# ── TestBlobStore ───────────────────────────────────────────────────── + + +class TestBlobStore: + @pytest.mark.asyncio + async def test_upload_returns_correct_key(self, s3_bucket: str) -> None: + store = BlobStore() + key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + assert key == "u1/tasks/r1" + + @pytest.mark.asyncio + async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + # Verify by downloading — no exception means object exists + retrieved = await store.download("u1", "u1/tasks/r1") + assert retrieved == _BLOB + + @pytest.mark.asyncio + async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest()) + result = await store.download("u1", "u1/notes/n1") + assert result == b"note-data" + + @pytest.mark.asyncio + async def test_delete_removes_object(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + await store.delete("u1", "u1/tasks/r1") + with pytest.raises(ClientError) as exc_info: + await store.download("u1", "u1/tasks/r1") + assert exc_info.value.response["Error"]["Code"] == "NoSuchKey" + + @pytest.mark.asyncio + async def test_delete_is_idempotent(self, s3_bucket: str) -> None: + store = BlobStore() + # Delete a key that never existed — should not raise + await store.delete("u1", "u1/tasks/nonexistent") + + @pytest.mark.asyncio + async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM) + keys = await store.list_keys("u1", "tasks") + assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"} + + @pytest.mark.asyncio + async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM) + keys = await store.list_keys("u1", "tasks") + assert "u1/notes/n1" not in keys + assert "u1/tasks/r1" in keys + + @pytest.mark.asyncio + async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM) + keys_u1 = await store.list_keys("u1", "tasks") + assert "u2/tasks/r1" not in keys_u1 + + @pytest.mark.asyncio + async def test_list_keys_empty_table(self, s3_bucket: str) -> None: + store = BlobStore() + keys = await store.list_keys("u1", "tasks") + assert keys == [] + + @pytest.mark.asyncio + async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + # Verify S3 metadata was set — check via head_object + with patch("app.storage.blob_store.settings") as mock_settings: + mock_settings.S3_BUCKET = _BUCKET + mock_settings.S3_REGION = _REGION + mock_settings.AWS_ACCESS_KEY_ID = "testing" + mock_settings.AWS_SECRET_ACCESS_KEY = "testing" + client = boto3.client("s3", region_name=_REGION) + response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1") + assert response.get("ServerSideEncryption") == "AES256" + + @pytest.mark.asyncio + async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None: + store = BlobStore() + await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM) + client = boto3.client("s3", region_name=_REGION) + response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1") + assert response["Metadata"]["checksum"] == _CHECKSUM + + +# ── _blob_to_vector helper ──────────────────────────────────────────── + + +class TestBlobToVector: + def test_returns_32_floats(self) -> None: + v = _blob_to_vector(b"test") + assert len(v) == 32 + + def test_all_values_in_range(self) -> None: + v = _blob_to_vector(b"test") + assert all(-1.0 <= x <= 1.0 for x in v) + + def test_deterministic(self) -> None: + assert _blob_to_vector(b"same") == _blob_to_vector(b"same") + + def test_different_blobs_different_vectors(self) -> None: + assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb") + + +# ── TestVectorStorePinecone ─────────────────────────────────────────── + + +class TestVectorStorePinecone: + def _store(self) -> VectorStore: + store = VectorStore() + store._use_pinecone = lambda: True # type: ignore[method-assign] + return store + + @pytest.mark.asyncio + async def test_upsert_calls_index_upsert(self) -> None: + mock_pc, mock_index = _pinecone_mock() + with patch("app.storage.vector_store.Pinecone", mock_pc): + store = self._store() + items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())] + await store.upsert("u1", items) + mock_index.upsert.assert_called_once() + call_kwargs = mock_index.upsert.call_args[1] + assert call_kwargs.get("namespace") == "u1" + + @pytest.mark.asyncio + async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None: + mock_pc, mock_index = _pinecone_mock() + with patch("app.storage.vector_store.Pinecone", mock_pc): + store = self._store() + items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())] + await store.upsert("u1", items) + vectors_arg = mock_index.upsert.call_args[1]["vectors"] + assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode() + + @pytest.mark.asyncio + async def test_search_calls_index_query(self) -> None: + mock_pc, mock_index = _pinecone_mock() + with patch("app.storage.vector_store.Pinecone", mock_pc): + store = self._store() + await store.search("u1", b"query-blob", top_k=5) + mock_index.query.assert_called_once() + query_kwargs = mock_index.query.call_args[1] + assert query_kwargs.get("namespace") == "u1" + assert query_kwargs.get("top_k") == 5 + assert query_kwargs.get("include_metadata") is True + + @pytest.mark.asyncio + async def test_search_returns_vector_search_results(self) -> None: + mock_pc, mock_index = _pinecone_mock() + with patch("app.storage.vector_store.Pinecone", mock_pc): + store = self._store() + results = await store.search("u1", b"query", top_k=10) + assert len(results) == 1 + assert isinstance(results[0], VectorSearchResult) + assert results[0].id == "v1" + assert results[0].score == 0.95 + assert results[0].blob == b"result-blob" + + @pytest.mark.asyncio + async def test_search_uses_derived_query_vector(self) -> None: + mock_pc, mock_index = _pinecone_mock() + with patch("app.storage.vector_store.Pinecone", mock_pc): + store = self._store() + await store.search("u1", b"query-blob", top_k=3) + expected_vector = _blob_to_vector(b"query-blob") + actual_vector = mock_index.query.call_args[1].get("vector") + assert actual_vector == expected_vector + + @pytest.mark.asyncio + async def test_delete_calls_index_delete(self) -> None: + mock_pc, mock_index = _pinecone_mock() + with patch("app.storage.vector_store.Pinecone", mock_pc): + store = self._store() + await store.delete("u1", ["v1", "v2"]) + mock_index.delete.assert_called_once() + delete_kwargs = mock_index.delete.call_args[1] + assert delete_kwargs.get("namespace") == "u1" + assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"} + + +# ── TestVectorStoreQdrant ───────────────────────────────────────────── + + +class TestVectorStoreQdrant: + def _store(self) -> VectorStore: + store = VectorStore() + store._use_pinecone = lambda: False # type: ignore[method-assign] + return store + + def _qdrant_mock(self) -> MagicMock: + mock_hit = MagicMock() + mock_hit.id = "v1" + mock_hit.score = 0.88 + mock_hit.payload = { + "blob": base64.b64encode(b"qdrant-result").decode(), + "user_id": "u1", + } + mock_client = MagicMock() + mock_client.search.return_value = [mock_hit] + return mock_client + + @pytest.mark.asyncio + async def test_upsert_calls_client_upsert(self) -> None: + mock_client = MagicMock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())] + await store.upsert("u1", items) + mock_client.upsert.assert_called_once() + + @pytest.mark.asyncio + async def test_upsert_uses_correct_collection(self) -> None: + mock_client = MagicMock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())] + await store.upsert("u1", items) + call_kwargs = mock_client.upsert.call_args[1] + assert call_kwargs["collection_name"] == "adiuva_vectors" + + @pytest.mark.asyncio + async def test_search_calls_client_search(self) -> None: + mock_client = self._qdrant_mock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + await store.search("u1", b"query", top_k=5) + mock_client.search.assert_called_once() + + @pytest.mark.asyncio + async def test_search_passes_limit(self) -> None: + mock_client = self._qdrant_mock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + await store.search("u1", b"query", top_k=7) + call_kwargs = mock_client.search.call_args[1] + assert call_kwargs.get("limit") == 7 + + @pytest.mark.asyncio + async def test_search_returns_vector_search_results(self) -> None: + mock_client = self._qdrant_mock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + results = await store.search("u1", b"query", top_k=5) + assert len(results) == 1 + assert isinstance(results[0], VectorSearchResult) + assert results[0].id == "v1" + assert results[0].score == 0.88 + assert results[0].blob == b"qdrant-result" + + @pytest.mark.asyncio + async def test_delete_calls_client_delete(self) -> None: + mock_client = MagicMock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + await store.delete("u1", ["v1", "v2"]) + mock_client.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_uses_correct_collection(self) -> None: + mock_client = MagicMock() + with patch("app.storage.vector_store.QdrantClient", return_value=mock_client): + store = self._store() + await store.delete("u1", ["v1"]) + call_kwargs = mock_client.delete.call_args[1] + assert call_kwargs["collection_name"] == "adiuva_vectors"