Refactor tests for execution plan and add comprehensive storage tests
- Updated `TestModuleSingletons` in `test_execution_plan.py` to reflect new agent templates and playbook names. - Changed assertions in playbook tests to match updated templates and agents. - Introduced `test_storage.py` to cover the storage layer, including encryption, BlobStore, and VectorStore functionalities. - Added tests for S3 interactions, ensuring upload, download, delete, and list operations work as expected. - Implemented mock tests for Pinecone and Qdrant vector stores to validate upsert, search, and delete operations.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Unit tests for all four chat agents with mocked LLM."""
|
||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -9,9 +9,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||
from app.agents.analytics_agent import AnalyticsAgent
|
||||
from app.agents.calendar_agent import CalendarAgent
|
||||
from app.agents.email_agent import EmailAgent
|
||||
from app.agents.checkpoint_agent import CheckpointAgent
|
||||
from app.agents.note_agent import NoteAgent
|
||||
from app.agents.project_agent import ProjectAgent
|
||||
from app.agents.task_agent import TaskAgent
|
||||
from app.core.agent_registry import registry
|
||||
|
||||
@@ -59,15 +59,15 @@ def _mock_llm_with_tool_call(
|
||||
class TestAgentRegistration:
|
||||
def test_all_agents_registered(self) -> None:
|
||||
names = {a["name"] for a in registry.list_agents()}
|
||||
assert {"task_agent", "calendar_agent", "email_agent", "analytics_agent"}.issubset(
|
||||
names
|
||||
)
|
||||
assert {
|
||||
"task_agent", "checkpoint_agent", "project_agent", "note_agent"
|
||||
}.issubset(names)
|
||||
|
||||
def test_registry_returns_correct_types(self) -> None:
|
||||
assert isinstance(registry.get("task_agent"), TaskAgent)
|
||||
assert isinstance(registry.get("calendar_agent"), CalendarAgent)
|
||||
assert isinstance(registry.get("email_agent"), EmailAgent)
|
||||
assert isinstance(registry.get("analytics_agent"), AnalyticsAgent)
|
||||
assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
|
||||
assert isinstance(registry.get("project_agent"), ProjectAgent)
|
||||
assert isinstance(registry.get("note_agent"), NoteAgent)
|
||||
|
||||
def test_descriptions_present(self) -> None:
|
||||
for agent_info in registry.list_agents():
|
||||
@@ -82,14 +82,23 @@ class TestTaskAgent:
|
||||
assert TaskAgent().get_name() == "task_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert TaskAgent().get_description() == "Manages tasks: create, update, list, suggest"
|
||||
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(TaskAgent().get_tools()) == 4
|
||||
assert len(TaskAgent().get_tools()) == 8
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in TaskAgent().get_tools()}
|
||||
assert names == {"create_task", "update_task", "list_tasks", "suggest_tasks"}
|
||||
assert names == {
|
||||
"list_tasks",
|
||||
"create_task",
|
||||
"update_task",
|
||||
"delete_task",
|
||||
"list_tasks_due_today",
|
||||
"list_task_comments",
|
||||
"add_task_comment",
|
||||
"delete_task_comment",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_returns_string(self) -> None:
|
||||
@@ -111,10 +120,10 @@ class TestTaskAgent:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_task",
|
||||
{"title": "Buy groceries", "priority": "low"},
|
||||
"Task 'Buy groceries' created with low priority.",
|
||||
"Task 'Buy groceries' created.",
|
||||
)
|
||||
result = await TaskAgent().handle("add a grocery task", {})
|
||||
assert result == "Task 'Buy groceries' created with low priority."
|
||||
assert result == "Task 'Buy groceries' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
@@ -123,20 +132,11 @@ class TestTaskAgent:
|
||||
result = await TaskAgent().handle("help", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_partial_context(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await TaskAgent().handle("list tasks", {"user_profile": {"id": "u1"}})
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_rich_context(self) -> None:
|
||||
context = {
|
||||
"user_profile": {"id": "u1", "tier": "pro"},
|
||||
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
||||
"relevant_documents": ["doc1"],
|
||||
"extra_plugin_data": {"batch_id": "b1"},
|
||||
}
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Tasks listed.")
|
||||
@@ -146,244 +146,475 @@ class TestTaskAgent:
|
||||
|
||||
class TestTaskAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_returns_valid_json(self) -> None:
|
||||
async def test_list_tasks_defaults(self) -> None:
|
||||
from app.agents.task_agent import list_tasks
|
||||
result = await list_tasks.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "tasks"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_with_status_filter(self) -> None:
|
||||
from app.agents.task_agent import list_tasks
|
||||
result = await list_tasks.ainvoke({"status": "done"})
|
||||
data = json.loads(result)
|
||||
assert data["filters"]["status"] == "done"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_defaults(self) -> None:
|
||||
from app.agents.task_agent import create_task
|
||||
result = await create_task.ainvoke({"title": "Test task", "priority": "high"})
|
||||
result = await create_task.ainvoke({"title": "Test task"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "tasks"
|
||||
assert data["data"]["title"] == "Test task"
|
||||
assert data["data"]["priority"] == "high"
|
||||
assert data["data"]["status"] == "todo"
|
||||
assert data["data"]["priority"] == "medium"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_returns_valid_json(self) -> None:
|
||||
async def test_create_task_with_all_fields(self) -> None:
|
||||
from app.agents.task_agent import create_task
|
||||
result = await create_task.ainvoke({
|
||||
"title": "Deploy",
|
||||
"priority": "high",
|
||||
"status": "in_progress",
|
||||
"project_id": "p1",
|
||||
"is_ai_suggested": 1,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["priority"] == "high"
|
||||
assert data["data"]["status"] == "in_progress"
|
||||
assert data["data"]["projectId"] == "p1"
|
||||
assert data["data"]["isAiSuggested"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_with_status(self) -> None:
|
||||
from app.agents.task_agent import update_task
|
||||
result = await update_task.ainvoke(
|
||||
{"task_id": "t1", "updates": '{"priority": "urgent"}'}
|
||||
)
|
||||
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "t1"
|
||||
assert data["data"]["updates"]["status"] == "done"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_returns_valid_json(self) -> None:
|
||||
from app.agents.task_agent import list_tasks
|
||||
result = await list_tasks.ainvoke({"status": "open"})
|
||||
async def test_update_task_empty_updates(self) -> None:
|
||||
from app.agents.task_agent import update_task
|
||||
result = await update_task.ainvoke({"task_id": "t1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_task(self) -> None:
|
||||
from app.agents.task_agent import delete_task
|
||||
result = await delete_task.ainvoke({"task_id": "t1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "tasks"
|
||||
assert data["data"]["id"] == "t1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_due_today(self) -> None:
|
||||
from app.agents.task_agent import list_tasks_due_today
|
||||
result = await list_tasks_due_today.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list_due_today"
|
||||
assert data["table"] == "tasks"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggest_tasks_returns_valid_json(self) -> None:
|
||||
from app.agents.task_agent import suggest_tasks
|
||||
result = await suggest_tasks.ainvoke({"context": "lots of meetings this week"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "suggest"
|
||||
|
||||
|
||||
# ── CalendarAgent ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCalendarAgent:
|
||||
def test_name(self) -> None:
|
||||
assert CalendarAgent().get_name() == "calendar_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert CalendarAgent().get_description() == "Calendar management: events, conflicts, scheduling"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(CalendarAgent().get_tools()) == 3
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in CalendarAgent().get_tools()}
|
||||
assert names == {"list_events", "detect_conflicts", "suggest_reschedule"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("No conflicts found.")
|
||||
result = await CalendarAgent().handle("check my schedule", {})
|
||||
assert result == "No conflicts found."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_list_events_tool_call(self) -> None:
|
||||
with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"list_events",
|
||||
{"date_range": "2024-01-01/2024-01-07"},
|
||||
"You have 3 events next week.",
|
||||
)
|
||||
result = await CalendarAgent().handle("what events do I have?", {})
|
||||
assert result == "You have 3 events next week."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.calendar_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await CalendarAgent().handle("reschedule meeting", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestCalendarAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_events_returns_valid_json(self) -> None:
|
||||
from app.agents.calendar_agent import list_events
|
||||
result = await list_events.ainvoke({"date_range": "2024-01-01/2024-01-07"})
|
||||
async def test_list_task_comments(self) -> None:
|
||||
from app.agents.task_agent import list_task_comments
|
||||
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "events"
|
||||
assert data["filters"]["date_range"] == "2024-01-01/2024-01-07"
|
||||
assert data["table"] == "taskComments"
|
||||
assert data["filters"]["taskId"] == "t1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detect_conflicts_returns_valid_json(self) -> None:
|
||||
from app.agents.calendar_agent import detect_conflicts
|
||||
result = await detect_conflicts.ainvoke({"events": "[]"})
|
||||
async def test_add_task_comment(self) -> None:
|
||||
from app.agents.task_agent import add_task_comment
|
||||
result = await add_task_comment.ainvoke({
|
||||
"task_id": "t1",
|
||||
"author": "Alice",
|
||||
"content": "Looks good!",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "analyse"
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "taskComments"
|
||||
assert data["data"]["taskId"] == "t1"
|
||||
assert data["data"]["author"] == "Alice"
|
||||
assert data["data"]["content"] == "Looks good!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggest_reschedule_returns_valid_json(self) -> None:
|
||||
from app.agents.calendar_agent import suggest_reschedule
|
||||
result = await suggest_reschedule.ainvoke({"conflict": '{"event": "standup"}'})
|
||||
async def test_delete_task_comment(self) -> None:
|
||||
from app.agents.task_agent import delete_task_comment
|
||||
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "suggest_reschedule"
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "taskComments"
|
||||
assert data["data"]["id"] == "c1"
|
||||
|
||||
|
||||
# ── EmailAgent ────────────────────────────────────────────────────────
|
||||
# ── CheckpointAgent ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmailAgent:
|
||||
class TestCheckpointAgent:
|
||||
def test_name(self) -> None:
|
||||
assert EmailAgent().get_name() == "email_agent"
|
||||
assert CheckpointAgent().get_name() == "checkpoint_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert EmailAgent().get_description() == "Email analysis: classify, extract actions, draft responses"
|
||||
assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(EmailAgent().get_tools()) == 3
|
||||
assert len(CheckpointAgent().get_tools()) == 4
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in EmailAgent().get_tools()}
|
||||
assert names == {"classify_email", "extract_action_items", "draft_response"}
|
||||
names = {t.name for t in CheckpointAgent().get_tools()}
|
||||
assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.email_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Email classified as action_required.")
|
||||
result = await EmailAgent().handle("classify this email", {})
|
||||
assert result == "Email classified as action_required."
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("No checkpoints found.")
|
||||
result = await CheckpointAgent().handle("list checkpoints", {})
|
||||
assert result == "No checkpoints found."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_classify_tool_call(self) -> None:
|
||||
with patch("app.agents.email_agent.ChatOpenAI") as mock_cls:
|
||||
async def test_handle_with_create_tool_call(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"classify_email",
|
||||
{"metadata": '{"subject": "URGENT: action needed"}'},
|
||||
"This email requires immediate action.",
|
||||
"create_checkpoint",
|
||||
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
||||
"Checkpoint 'MVP Launch' created.",
|
||||
)
|
||||
result = await EmailAgent().handle("what is this email about?", {})
|
||||
assert result == "This email requires immediate action."
|
||||
result = await CheckpointAgent().handle("add MVP checkpoint", {})
|
||||
assert result == "Checkpoint 'MVP Launch' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.email_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await EmailAgent().handle("draft a reply", {})
|
||||
result = await CheckpointAgent().handle("show milestones", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestEmailAgentTools:
|
||||
class TestCheckpointAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_email_returns_valid_json(self) -> None:
|
||||
from app.agents.email_agent import classify_email
|
||||
result = await classify_email.ainvoke({"metadata": '{"subject": "Meeting"}' })
|
||||
async def test_list_checkpoints_no_project(self) -> None:
|
||||
from app.agents.checkpoint_agent import list_checkpoints
|
||||
result = await list_checkpoints.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "classify"
|
||||
assert "result" in data
|
||||
assert "category" in data["result"]
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "checkpoints"
|
||||
assert data["filters"]["projectId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_action_items_returns_valid_json(self) -> None:
|
||||
from app.agents.email_agent import extract_action_items
|
||||
result = await extract_action_items.ainvoke({"metadata": '{"subject": "Follow up"}'})
|
||||
async def test_list_checkpoints_with_project(self) -> None:
|
||||
from app.agents.checkpoint_agent import list_checkpoints
|
||||
result = await list_checkpoints.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "extract"
|
||||
assert "action_items" in data["result"]
|
||||
assert data["filters"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_draft_response_returns_valid_json(self) -> None:
|
||||
from app.agents.email_agent import draft_response
|
||||
result = await draft_response.ainvoke({"thread_context": '{"thread_id": "t1"}'})
|
||||
async def test_create_checkpoint(self) -> None:
|
||||
from app.agents.checkpoint_agent import create_checkpoint
|
||||
result = await create_checkpoint.ainvoke({
|
||||
"project_id": "p1",
|
||||
"title": "Beta release",
|
||||
"date": 1700000000000,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "draft"
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "checkpoints"
|
||||
assert data["data"]["projectId"] == "p1"
|
||||
assert data["data"]["title"] == "Beta release"
|
||||
assert data["data"]["date"] == 1700000000000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkpoint_ai_suggested(self) -> None:
|
||||
from app.agents.checkpoint_agent import create_checkpoint
|
||||
result = await create_checkpoint.ainvoke({
|
||||
"project_id": "p1",
|
||||
"title": "Review",
|
||||
"date": 1700000000000,
|
||||
"is_ai_suggested": 1,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["isAiSuggested"] == 1
|
||||
assert data["data"]["isApproved"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_checkpoint_approve(self) -> None:
|
||||
from app.agents.checkpoint_agent import update_checkpoint
|
||||
result = await update_checkpoint.ainvoke({
|
||||
"checkpoint_id": "c1",
|
||||
"is_approved": 1,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "c1"
|
||||
assert data["data"]["updates"]["isApproved"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_checkpoint_empty_updates(self) -> None:
|
||||
from app.agents.checkpoint_agent import update_checkpoint
|
||||
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_checkpoint(self) -> None:
|
||||
from app.agents.checkpoint_agent import delete_checkpoint
|
||||
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "checkpoints"
|
||||
assert data["data"]["id"] == "c1"
|
||||
|
||||
|
||||
# ── AnalyticsAgent ────────────────────────────────────────────────────
|
||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAnalyticsAgent:
|
||||
class TestProjectAgent:
|
||||
def test_name(self) -> None:
|
||||
assert AnalyticsAgent().get_name() == "analytics_agent"
|
||||
assert ProjectAgent().get_name() == "project_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert AnalyticsAgent().get_description() == "Workspace analytics: metrics, reports, trends"
|
||||
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(AnalyticsAgent().get_tools()) == 3
|
||||
assert len(ProjectAgent().get_tools()) == 6
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in AnalyticsAgent().get_tools()}
|
||||
assert names == {"calculate_metrics", "generate_report", "trend_analysis"}
|
||||
names = {t.name for t in ProjectAgent().get_tools()}
|
||||
assert names == {
|
||||
"list_projects",
|
||||
"list_all_projects",
|
||||
"get_project",
|
||||
"create_project",
|
||||
"update_project",
|
||||
"delete_project",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Completion rate is 78%.")
|
||||
result = await AnalyticsAgent().handle("show my metrics", {})
|
||||
assert result == "Completion rate is 78%."
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
||||
result = await ProjectAgent().handle("show my projects", {})
|
||||
assert result == "Project Alpha is active."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_generate_report_tool_call(self) -> None:
|
||||
with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls:
|
||||
async def test_handle_with_create_project_tool_call(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"generate_report",
|
||||
{"period": "last_7_days", "data": "[]"},
|
||||
"Weekly report: 12 tasks completed, 2 overdue.",
|
||||
"create_project",
|
||||
{"name": "Pippo"},
|
||||
"Project 'Pippo' created.",
|
||||
)
|
||||
result = await AnalyticsAgent().handle("weekly report", {})
|
||||
assert result == "Weekly report: 12 tasks completed, 2 overdue."
|
||||
result = await ProjectAgent().handle("create project Pippo", {})
|
||||
assert result == "Project 'Pippo' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.analytics_agent.ChatOpenAI") as mock_cls:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await AnalyticsAgent().handle("analyse trends", {})
|
||||
result = await ProjectAgent().handle("archive old project", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestAnalyticsAgentTools:
|
||||
class TestProjectAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_metrics_returns_valid_json(self) -> None:
|
||||
from app.agents.analytics_agent import calculate_metrics
|
||||
result = await calculate_metrics.ainvoke({"task_data": "[]"})
|
||||
async def test_list_projects_defaults(self) -> None:
|
||||
from app.agents.project_agent import list_projects
|
||||
result = await list_projects.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "calculate"
|
||||
assert "result" in data
|
||||
assert "completion_rate" in data["result"]
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "projects"
|
||||
assert data["filters"]["includeArchived"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_report_returns_valid_json(self) -> None:
|
||||
from app.agents.analytics_agent import generate_report
|
||||
result = await generate_report.ainvoke({"period": "last_7_days", "data": "[]"})
|
||||
async def test_list_projects_include_archived(self) -> None:
|
||||
from app.agents.project_agent import list_projects
|
||||
result = await list_projects.ainvoke({"include_archived": 1})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "report"
|
||||
assert data["period"] == "last_7_days"
|
||||
assert data["filters"]["includeArchived"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trend_analysis_returns_valid_json(self) -> None:
|
||||
from app.agents.analytics_agent import trend_analysis
|
||||
result = await trend_analysis.ainvoke({"data_points": "[]"})
|
||||
async def test_list_all_projects(self) -> None:
|
||||
from app.agents.project_agent import list_all_projects
|
||||
result = await list_all_projects.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "trend"
|
||||
assert "result" in data
|
||||
assert "anomalies" in data["result"]
|
||||
assert data["action"] == "list_all"
|
||||
assert data["table"] == "projects"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project(self) -> None:
|
||||
from app.agents.project_agent import get_project
|
||||
result = await get_project.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "get"
|
||||
assert data["table"] == "projects"
|
||||
assert data["data"]["id"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_name_only(self) -> None:
|
||||
from app.agents.project_agent import create_project
|
||||
result = await create_project.ainvoke({"name": "Alpha"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["data"]["name"] == "Alpha"
|
||||
assert data["data"]["clientId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_with_client(self) -> None:
|
||||
from app.agents.project_agent import create_project
|
||||
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["clientId"] == "cl1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_archive(self) -> None:
|
||||
from app.agents.project_agent import update_project
|
||||
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "p1"
|
||||
assert data["data"]["updates"]["status"] == "archived"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_empty_updates(self) -> None:
|
||||
from app.agents.project_agent import update_project
|
||||
result = await update_project.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_project(self) -> None:
|
||||
from app.agents.project_agent import delete_project
|
||||
result = await delete_project.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["data"]["id"] == "p1"
|
||||
|
||||
|
||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNoteAgent:
|
||||
def test_name(self) -> None:
|
||||
assert NoteAgent().get_name() == "note_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(NoteAgent().get_tools()) == 5
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in NoteAgent().get_tools()}
|
||||
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Note created.")
|
||||
result = await NoteAgent().handle("create a note", {})
|
||||
assert result == "Note created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_note_tool_call(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_note",
|
||||
{"title": "Daily log", "content": "# Today\nAll good."},
|
||||
"Note 'Daily log' created.",
|
||||
)
|
||||
result = await NoteAgent().handle("log today's progress", {})
|
||||
assert result == "Note 'Daily log' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await NoteAgent().handle("show notes", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestNoteAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_notes_no_project(self) -> None:
|
||||
from app.agents.note_agent import list_notes
|
||||
result = await list_notes.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "notes"
|
||||
assert data["filters"]["projectId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_notes_with_project(self) -> None:
|
||||
from app.agents.note_agent import list_notes
|
||||
result = await list_notes.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["filters"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_note(self) -> None:
|
||||
from app.agents.note_agent import get_note
|
||||
result = await get_note.ainvoke({"note_id": "n1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "get"
|
||||
assert data["table"] == "notes"
|
||||
assert data["data"]["id"] == "n1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note_minimal(self) -> None:
|
||||
from app.agents.note_agent import create_note
|
||||
result = await create_note.ainvoke({
|
||||
"title": "Daily log",
|
||||
"content": "# Today\nAll good.",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "notes"
|
||||
assert data["data"]["title"] == "Daily log"
|
||||
assert data["data"]["content"] == "# Today\nAll good."
|
||||
assert data["data"]["projectId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note_with_project(self) -> None:
|
||||
from app.agents.note_agent import create_note
|
||||
result = await create_note.ainvoke({
|
||||
"title": "Sprint notes",
|
||||
"content": "## Sprint 1",
|
||||
"project_id": "p1",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note_content_only(self) -> None:
|
||||
from app.agents.note_agent import update_note
|
||||
result = await update_note.ainvoke({
|
||||
"note_id": "n1",
|
||||
"content": "# Updated content",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "n1"
|
||||
assert data["data"]["updates"]["content"] == "# Updated content"
|
||||
assert "title" not in data["data"]["updates"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note_empty_updates(self) -> None:
|
||||
from app.agents.note_agent import update_note
|
||||
result = await update_note.ainvoke({"note_id": "n1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_note(self) -> None:
|
||||
from app.agents.note_agent import delete_note
|
||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "notes"
|
||||
assert data["data"]["id"] == "n1"
|
||||
|
||||
@@ -243,14 +243,14 @@ class TestPlanCache:
|
||||
|
||||
class TestModuleSingletons:
|
||||
def test_template_registry_has_all_agent_defaults(self) -> None:
|
||||
for agent in ("task_agent", "calendar_agent", "email_agent", "analytics_agent"):
|
||||
for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
|
||||
assert template_registry.has(f"tpl_{agent}_default"), (
|
||||
f"Missing template: tpl_{agent}_default"
|
||||
)
|
||||
|
||||
def test_template_registry_has_operation_templates(self) -> None:
|
||||
assert template_registry.has("tpl_email_extract_action_items")
|
||||
assert template_registry.has("tpl_analytics_weekly_summary")
|
||||
assert template_registry.has("tpl_task_extract_from_project")
|
||||
assert template_registry.has("tpl_note_weekly_summary")
|
||||
|
||||
def test_template_registry_get_returns_non_empty_string(self) -> None:
|
||||
text = template_registry.get("tpl_task_agent_default")
|
||||
@@ -260,20 +260,20 @@ class TestModuleSingletons:
|
||||
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
|
||||
assert len(plan_cache.get_all_playbooks()) >= 2
|
||||
|
||||
def test_playbook_create_task_from_email(self) -> None:
|
||||
plan = plan_cache.get_plan("create_task_from_email")
|
||||
def test_playbook_create_tasks_from_project(self) -> None:
|
||||
plan = plan_cache.get_plan("create_tasks_from_project")
|
||||
assert plan is not None
|
||||
assert plan.agent == "email_agent"
|
||||
assert plan.agent == "project_agent"
|
||||
assert len(plan.steps) == 2
|
||||
assert plan.steps[0].prompt_template == "tpl_email_extract_action_items"
|
||||
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
|
||||
assert plan.steps[1].data_from_step == 0
|
||||
|
||||
def test_playbook_generate_weekly_report(self) -> None:
|
||||
plan = plan_cache.get_plan("generate_weekly_report")
|
||||
def test_playbook_generate_weekly_note(self) -> None:
|
||||
plan = plan_cache.get_plan("generate_weekly_note")
|
||||
assert plan is not None
|
||||
assert plan.agent == "analytics_agent"
|
||||
assert plan.agent == "note_agent"
|
||||
assert len(plan.steps) == 2
|
||||
assert plan.steps[0].prompt_template == "tpl_analytics_weekly_summary"
|
||||
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
|
||||
assert plan.steps[1].data_from_step == 0
|
||||
|
||||
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
|
||||
|
||||
385
tests/test_storage.py
Normal file
385
tests/test_storage.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Tests for the storage layer: encryption, BlobStore, and VectorStore."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||
from app.schemas import VectorItem, VectorSearchResult
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
_BLOB = b"encrypted-payload-opaque-to-server"
|
||||
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||
_BUCKET = "test-bucket"
|
||||
_REGION = "us-east-1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_bucket():
|
||||
"""Create a mocked S3 bucket and expose its name."""
|
||||
with mock_aws():
|
||||
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||
os.environ.setdefault("AWS_DEFAULT_REGION", _REGION)
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
client.create_bucket(Bucket=_BUCKET)
|
||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||
mock_settings.S3_BUCKET = _BUCKET
|
||||
mock_settings.S3_REGION = _REGION
|
||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||
yield _BUCKET
|
||||
|
||||
|
||||
def _pinecone_mock():
|
||||
"""Return a mock Pinecone index with realistic return shapes."""
|
||||
mock_index = MagicMock()
|
||||
mock_index.query.return_value = {
|
||||
"matches": [
|
||||
{
|
||||
"id": "v1",
|
||||
"score": 0.95,
|
||||
"metadata": {
|
||||
"blob": base64.b64encode(b"result-blob").decode(),
|
||||
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
|
||||
"user_id": "u1",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_pc = MagicMock()
|
||||
mock_pc.return_value.Index.return_value = mock_index
|
||||
return mock_pc, mock_index
|
||||
|
||||
|
||||
# ── TestEncryption ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEncryption:
|
||||
def test_verify_checksum_correct(self) -> None:
|
||||
assert verify_checksum(_BLOB, _CHECKSUM) is True
|
||||
|
||||
def test_verify_checksum_wrong(self) -> None:
|
||||
assert verify_checksum(_BLOB, "0" * 64) is False
|
||||
|
||||
def test_verify_checksum_empty_checksum(self) -> None:
|
||||
assert verify_checksum(_BLOB, "") is False
|
||||
|
||||
def test_verify_checksum_empty_blob(self) -> None:
|
||||
expected = hashlib.sha256(b"").hexdigest()
|
||||
assert verify_checksum(b"", expected) is True
|
||||
|
||||
def test_verify_checksum_tampered_blob(self) -> None:
|
||||
tampered = _BLOB + b"\x00"
|
||||
assert verify_checksum(tampered, _CHECKSUM) is False
|
||||
|
||||
def test_reject_if_tampered_passes_when_valid(self) -> None:
|
||||
# Should not raise
|
||||
reject_if_tampered(_BLOB, _CHECKSUM)
|
||||
|
||||
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
reject_if_tampered(_BLOB, "bad" * 20)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
reject_if_tampered(_BLOB, "bad" * 20)
|
||||
assert "checksum" in exc_info.value.detail.lower()
|
||||
|
||||
def test_checksum_is_sha256_hex(self) -> None:
|
||||
cs = hashlib.sha256(_BLOB).hexdigest()
|
||||
assert len(cs) == 64
|
||||
assert all(c in "0123456789abcdef" for c in cs)
|
||||
|
||||
|
||||
# ── TestBlobStore ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBlobStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
assert key == "u1/tasks/r1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
# Verify by downloading — no exception means object exists
|
||||
retrieved = await store.download("u1", "u1/tasks/r1")
|
||||
assert retrieved == _BLOB
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
|
||||
result = await store.download("u1", "u1/notes/n1")
|
||||
assert result == b"note-data"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_object(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.delete("u1", "u1/tasks/r1")
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
await store.download("u1", "u1/tasks/r1")
|
||||
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
# Delete a key that never existed — should not raise
|
||||
await store.delete("u1", "u1/tasks/nonexistent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert "u1/notes/n1" not in keys
|
||||
assert "u1/tasks/r1" in keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
keys_u1 = await store.list_keys("u1", "tasks")
|
||||
assert "u2/tasks/r1" not in keys_u1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert keys == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
# Verify S3 metadata was set — check via head_object
|
||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||
mock_settings.S3_BUCKET = _BUCKET
|
||||
mock_settings.S3_REGION = _REGION
|
||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||
assert response.get("ServerSideEncryption") == "AES256"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||
assert response["Metadata"]["checksum"] == _CHECKSUM
|
||||
|
||||
|
||||
# ── _blob_to_vector helper ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBlobToVector:
|
||||
def test_returns_32_floats(self) -> None:
|
||||
v = _blob_to_vector(b"test")
|
||||
assert len(v) == 32
|
||||
|
||||
def test_all_values_in_range(self) -> None:
|
||||
v = _blob_to_vector(b"test")
|
||||
assert all(-1.0 <= x <= 1.0 for x in v)
|
||||
|
||||
def test_deterministic(self) -> None:
|
||||
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
|
||||
|
||||
def test_different_blobs_different_vectors(self) -> None:
|
||||
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
|
||||
|
||||
|
||||
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVectorStorePinecone:
|
||||
def _store(self) -> VectorStore:
|
||||
store = VectorStore()
|
||||
store._use_pinecone = lambda: True # type: ignore[method-assign]
|
||||
return store
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_calls_index_upsert(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
mock_index.upsert.assert_called_once()
|
||||
call_kwargs = mock_index.upsert.call_args[1]
|
||||
assert call_kwargs.get("namespace") == "u1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
|
||||
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_calls_index_query(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query-blob", top_k=5)
|
||||
mock_index.query.assert_called_once()
|
||||
query_kwargs = mock_index.query.call_args[1]
|
||||
assert query_kwargs.get("namespace") == "u1"
|
||||
assert query_kwargs.get("top_k") == 5
|
||||
assert query_kwargs.get("include_metadata") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_vector_search_results(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
results = await store.search("u1", b"query", top_k=10)
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VectorSearchResult)
|
||||
assert results[0].id == "v1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].blob == b"result-blob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_uses_derived_query_vector(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query-blob", top_k=3)
|
||||
expected_vector = _blob_to_vector(b"query-blob")
|
||||
actual_vector = mock_index.query.call_args[1].get("vector")
|
||||
assert actual_vector == expected_vector
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_calls_index_delete(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1", "v2"])
|
||||
mock_index.delete.assert_called_once()
|
||||
delete_kwargs = mock_index.delete.call_args[1]
|
||||
assert delete_kwargs.get("namespace") == "u1"
|
||||
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
|
||||
|
||||
|
||||
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVectorStoreQdrant:
|
||||
def _store(self) -> VectorStore:
|
||||
store = VectorStore()
|
||||
store._use_pinecone = lambda: False # type: ignore[method-assign]
|
||||
return store
|
||||
|
||||
def _qdrant_mock(self) -> MagicMock:
|
||||
mock_hit = MagicMock()
|
||||
mock_hit.id = "v1"
|
||||
mock_hit.score = 0.88
|
||||
mock_hit.payload = {
|
||||
"blob": base64.b64encode(b"qdrant-result").decode(),
|
||||
"user_id": "u1",
|
||||
}
|
||||
mock_client = MagicMock()
|
||||
mock_client.search.return_value = [mock_hit]
|
||||
return mock_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_calls_client_upsert(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
mock_client.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_uses_correct_collection(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
call_kwargs = mock_client.upsert.call_args[1]
|
||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_calls_client_search(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query", top_k=5)
|
||||
mock_client.search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_passes_limit(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query", top_k=7)
|
||||
call_kwargs = mock_client.search.call_args[1]
|
||||
assert call_kwargs.get("limit") == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_vector_search_results(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
results = await store.search("u1", b"query", top_k=5)
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VectorSearchResult)
|
||||
assert results[0].id == "v1"
|
||||
assert results[0].score == 0.88
|
||||
assert results[0].blob == b"qdrant-result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_calls_client_delete(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1", "v2"])
|
||||
mock_client.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_uses_correct_collection(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1"])
|
||||
call_kwargs = mock_client.delete.call_args[1]
|
||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||
Reference in New Issue
Block a user