- Updated `TestModuleSingletons` in `test_execution_plan.py` to reflect the new agent templates and playbook names.
- Changed assertions in the playbook tests to match the updated templates and agents.
- Introduced `test_storage.py` to cover the storage layer, including encryption, `BlobStore`, and `VectorStore` functionality.
- Added tests for S3 interactions, ensuring that upload, download, delete, and list operations work as expected.
- Implemented mock-based tests for the Pinecone and Qdrant vector stores to validate upsert, search, and delete operations.
386 lines · 16 KiB · Python
"""Tests for the storage layer: encryption, BlobStore, and VectorStore."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import boto3
|
|
import pytest
|
|
from botocore.exceptions import ClientError
|
|
from moto import mock_aws
|
|
|
|
from app.storage.encryption import reject_if_tampered, verify_checksum
|
|
from app.storage.blob_store import BlobStore
|
|
from app.storage.vector_store import VectorStore, _blob_to_vector
|
|
from app.schemas import VectorItem, VectorSearchResult
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────
|
|
|
|
_BLOB = b"encrypted-payload-opaque-to-server"
|
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
|
_BUCKET = "test-bucket"
|
|
_REGION = "us-east-1"
|
|
|
|
|
|
@pytest.fixture
def s3_bucket(monkeypatch: pytest.MonkeyPatch):
    """Create an empty moto-mocked S3 bucket and yield its name.

    Fake AWS credentials are exported *before* the moto mock starts, and via
    ``monkeypatch`` so they are removed again when the test finishes instead
    of leaking into the rest of the session. While the fixture is active,
    ``app.storage.blob_store.settings`` is patched so ``BlobStore`` targets
    the mocked bucket.
    """
    # setenv (not setdefault): real credentials already present in the
    # environment must never be picked up by the clients under test.
    monkeypatch.setenv("AWS_ACCESS_KEY_ID", "testing")
    monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "testing")
    monkeypatch.setenv("AWS_SECURITY_TOKEN", "testing")
    monkeypatch.setenv("AWS_SESSION_TOKEN", "testing")
    monkeypatch.setenv("AWS_DEFAULT_REGION", _REGION)

    with mock_aws():
        client = boto3.client("s3", region_name=_REGION)
        client.create_bucket(Bucket=_BUCKET)

        with patch("app.storage.blob_store.settings") as mock_settings:
            mock_settings.S3_BUCKET = _BUCKET
            mock_settings.S3_REGION = _REGION
            mock_settings.AWS_ACCESS_KEY_ID = "testing"
            mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
            yield _BUCKET
|
|
|
|
|
|
def _pinecone_mock():
|
|
"""Return a mock Pinecone index with realistic return shapes."""
|
|
mock_index = MagicMock()
|
|
mock_index.query.return_value = {
|
|
"matches": [
|
|
{
|
|
"id": "v1",
|
|
"score": 0.95,
|
|
"metadata": {
|
|
"blob": base64.b64encode(b"result-blob").decode(),
|
|
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
|
|
"user_id": "u1",
|
|
},
|
|
}
|
|
]
|
|
}
|
|
mock_pc = MagicMock()
|
|
mock_pc.return_value.Index.return_value = mock_index
|
|
return mock_pc, mock_index
|
|
|
|
|
|
# ── TestEncryption ────────────────────────────────────────────────────
|
|
|
|
|
|
class TestEncryption:
    """SHA-256 integrity checks: ``verify_checksum`` and ``reject_if_tampered``."""

    # A well-formed SHA-256 hex digest that does not belong to _BLOB. The
    # rejection tests use it so they exercise the digest-mismatch path itself,
    # rather than a malformed checksum string that might be rejected earlier.
    _WRONG_CHECKSUM = hashlib.sha256(_BLOB + b"\x00").hexdigest()

    def test_verify_checksum_correct(self) -> None:
        assert verify_checksum(_BLOB, _CHECKSUM) is True

    def test_verify_checksum_wrong(self) -> None:
        assert verify_checksum(_BLOB, "0" * 64) is False

    def test_verify_checksum_empty_checksum(self) -> None:
        assert verify_checksum(_BLOB, "") is False

    def test_verify_checksum_empty_blob(self) -> None:
        expected = hashlib.sha256(b"").hexdigest()
        assert verify_checksum(b"", expected) is True

    def test_verify_checksum_tampered_blob(self) -> None:
        tampered = _BLOB + b"\x00"
        assert verify_checksum(tampered, _CHECKSUM) is False

    def test_reject_if_tampered_passes_when_valid(self) -> None:
        # Should not raise
        reject_if_tampered(_BLOB, _CHECKSUM)

    def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            reject_if_tampered(_BLOB, self._WRONG_CHECKSUM)
        assert exc_info.value.status_code == 400

    def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as exc_info:
            reject_if_tampered(_BLOB, self._WRONG_CHECKSUM)
        assert "checksum" in exc_info.value.detail.lower()

    def test_checksum_is_sha256_hex(self) -> None:
        # The reference checksum is a 64-char lowercase hex string...
        assert len(_CHECKSUM) == 64
        assert set(_CHECKSUM) <= set("0123456789abcdef")
        # ...and verify_checksum accepts only the SHA-256 digest: digests of
        # the same blob under other algorithms must be rejected. (Only
        # algorithms that remain available on FIPS-restricted builds are used.)
        for algorithm in ("sha1", "sha512", "sha3_256"):
            other = hashlib.new(algorithm, _BLOB).hexdigest()
            assert verify_checksum(_BLOB, other) is False
|
|
|
|
|
|
# ── TestBlobStore ─────────────────────────────────────────────────────
|
|
|
|
|
|
class TestBlobStore:
    """BlobStore behaviour against the moto-backed bucket from ``s3_bucket``.

    Each test requests ``s3_bucket`` so that the mocked bucket and the patched
    settings are active for its duration.
    """

    @staticmethod
    def _raw_client():
        """Return a plain boto3 S3 client for inspecting the bucket directly.

        Assertions made through this client do not go through BlobStore's own
        read path, so they confirm what was actually written to S3.
        """
        return boto3.client("s3", region_name=_REGION)

    @pytest.mark.asyncio
    async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
        store = BlobStore()
        key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        assert key == "u1/tasks/r1"

    @pytest.mark.asyncio
    async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        # Check S3 directly rather than via BlobStore.download, which is
        # covered separately by test_download_retrieves_same_bytes.
        head = self._raw_client().head_object(Bucket=s3_bucket, Key="u1/tasks/r1")
        assert head["ContentLength"] == len(_BLOB)

    @pytest.mark.asyncio
    async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload(
            "u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest()
        )
        result = await store.download("u1", "u1/notes/n1")
        assert result == b"note-data"

    @pytest.mark.asyncio
    async def test_delete_removes_object(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.delete("u1", "u1/tasks/r1")
        with pytest.raises(ClientError) as exc_info:
            await store.download("u1", "u1/tasks/r1")
        assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"

    @pytest.mark.asyncio
    async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
        store = BlobStore()
        # Deleting a key that never existed must not raise.
        await store.delete("u1", "u1/tasks/nonexistent")

    @pytest.mark.asyncio
    async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
        keys = await store.list_keys("u1", "tasks")
        assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}

    @pytest.mark.asyncio
    async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
        keys = await store.list_keys("u1", "tasks")
        # Exact match: the "notes" key must be absent AND the "tasks" key present.
        assert set(keys) == {"u1/tasks/r1"}

    @pytest.mark.asyncio
    async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
        # Assert the full key set per user. Checking only that the other
        # user's key is absent would also pass if list_keys returned nothing.
        assert set(await store.list_keys("u1", "tasks")) == {"u1/tasks/r1"}
        assert set(await store.list_keys("u2", "tasks")) == {"u2/tasks/r1"}

    @pytest.mark.asyncio
    async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
        store = BlobStore()
        keys = await store.list_keys("u1", "tasks")
        assert keys == []

    @pytest.mark.asyncio
    async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        # S3 reports SSE-S3 (S3-managed keys) as ServerSideEncryption == "AES256".
        head = self._raw_client().head_object(Bucket=s3_bucket, Key="u1/tasks/r1")
        assert head.get("ServerSideEncryption") == "AES256"

    @pytest.mark.asyncio
    async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        head = self._raw_client().head_object(Bucket=s3_bucket, Key="u1/tasks/r1")
        assert head["Metadata"]["checksum"] == _CHECKSUM
|
|
|
|
|
|
# ── _blob_to_vector helper ────────────────────────────────────────────
|
|
|
|
|
|
class TestBlobToVector:
    """Properties of ``_blob_to_vector``, which turns a blob into a query vector."""

    def test_returns_32_floats(self) -> None:
        vector = _blob_to_vector(b"test")
        assert len(vector) == 32

    def test_all_values_in_range(self) -> None:
        # Collect anything outside [-1.0, 1.0]; the list must come back empty.
        outliers = [x for x in _blob_to_vector(b"test") if not -1.0 <= x <= 1.0]
        assert not outliers

    def test_deterministic(self) -> None:
        first = _blob_to_vector(b"same")
        second = _blob_to_vector(b"same")
        assert first == second

    def test_different_blobs_different_vectors(self) -> None:
        vec_a, vec_b = _blob_to_vector(b"aaa"), _blob_to_vector(b"bbb")
        assert vec_a != vec_b
|
|
|
|
|
|
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
|
|
|
|
|
class TestVectorStorePinecone:
    """VectorStore on the Pinecone path, with ``Pinecone`` replaced by ``_pinecone_mock``."""

    def _store(self) -> VectorStore:
        """Build a VectorStore whose backend switch always selects Pinecone."""
        vs = VectorStore()
        vs._use_pinecone = lambda: True  # type: ignore[method-assign]
        return vs

    @pytest.mark.asyncio
    async def test_upsert_calls_index_upsert(self) -> None:
        pc_cls, index = _pinecone_mock()
        blob = b"enc-blob"
        with patch("app.storage.vector_store.Pinecone", pc_cls):
            vs = self._store()
            batch = [VectorItem(id="v1", blob=blob, checksum=hashlib.sha256(blob).hexdigest())]
            await vs.upsert("u1", batch)
        index.upsert.assert_called_once()
        # The namespace sent to Pinecone must equal the user id given to upsert.
        assert index.upsert.call_args.kwargs.get("namespace") == "u1"

    @pytest.mark.asyncio
    async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
        pc_cls, index = _pinecone_mock()
        secret = b"secret"
        with patch("app.storage.vector_store.Pinecone", pc_cls):
            vs = self._store()
            batch = [VectorItem(id="v1", blob=secret, checksum=hashlib.sha256(secret).hexdigest())]
            await vs.upsert("u1", batch)
        sent = index.upsert.call_args.kwargs["vectors"]
        assert sent[0]["metadata"]["blob"] == base64.b64encode(secret).decode()

    @pytest.mark.asyncio
    async def test_search_calls_index_query(self) -> None:
        pc_cls, index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", pc_cls):
            await self._store().search("u1", b"query-blob", top_k=5)
        index.query.assert_called_once()
        kwargs = index.query.call_args.kwargs
        assert kwargs.get("namespace") == "u1"
        assert kwargs.get("top_k") == 5
        assert kwargs.get("include_metadata") is True

    @pytest.mark.asyncio
    async def test_search_returns_vector_search_results(self) -> None:
        pc_cls, _index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", pc_cls):
            found = await self._store().search("u1", b"query", top_k=10)
        assert len(found) == 1
        hit = found[0]
        assert isinstance(hit, VectorSearchResult)
        assert hit.id == "v1"
        assert hit.score == 0.95
        # The base64 metadata from the mocked match comes back decoded.
        assert hit.blob == b"result-blob"

    @pytest.mark.asyncio
    async def test_search_uses_derived_query_vector(self) -> None:
        pc_cls, index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", pc_cls):
            await self._store().search("u1", b"query-blob", top_k=3)
        sent_vector = index.query.call_args.kwargs.get("vector")
        assert sent_vector == _blob_to_vector(b"query-blob")

    @pytest.mark.asyncio
    async def test_delete_calls_index_delete(self) -> None:
        pc_cls, index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", pc_cls):
            await self._store().delete("u1", ["v1", "v2"])
        index.delete.assert_called_once()
        kwargs = index.delete.call_args.kwargs
        assert kwargs.get("namespace") == "u1"
        assert set(kwargs.get("ids", [])) == {"v1", "v2"}
|
|
|
|
|
|
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
|
|
|
|
|
class TestVectorStoreQdrant:
    """VectorStore on the Qdrant path, with ``QdrantClient`` replaced by a mock."""

    # Collection name the tests expect VectorStore to target.
    _COLLECTION = "adiuva_vectors"
    # Patch target: the client class as looked up by the module under test.
    _CLIENT_PATH = "app.storage.vector_store.QdrantClient"

    def _store(self) -> VectorStore:
        """Build a VectorStore whose backend switch never selects Pinecone."""
        vs = VectorStore()
        vs._use_pinecone = lambda: False  # type: ignore[method-assign]
        return vs

    def _qdrant_mock(self) -> MagicMock:
        """Return a mock client whose ``search`` yields one hit ``"v1"`` (score 0.88)."""
        hit = MagicMock()
        hit.id, hit.score = "v1", 0.88
        hit.payload = {
            "blob": base64.b64encode(b"qdrant-result").decode(),
            "user_id": "u1",
        }
        client = MagicMock()
        client.search.return_value = [hit]
        return client

    @staticmethod
    def _item(blob: bytes = b"enc") -> VectorItem:
        """Return a single VectorItem ``"v1"`` carrying *blob* and its SHA-256 checksum."""
        return VectorItem(id="v1", blob=blob, checksum=hashlib.sha256(blob).hexdigest())

    @pytest.mark.asyncio
    async def test_upsert_calls_client_upsert(self) -> None:
        client = MagicMock()
        with patch(self._CLIENT_PATH, return_value=client):
            await self._store().upsert("u1", [self._item()])
        client.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_upsert_uses_correct_collection(self) -> None:
        client = MagicMock()
        with patch(self._CLIENT_PATH, return_value=client):
            await self._store().upsert("u1", [self._item()])
        assert client.upsert.call_args.kwargs["collection_name"] == self._COLLECTION

    @pytest.mark.asyncio
    async def test_search_calls_client_search(self) -> None:
        client = self._qdrant_mock()
        with patch(self._CLIENT_PATH, return_value=client):
            await self._store().search("u1", b"query", top_k=5)
        client.search.assert_called_once()

    @pytest.mark.asyncio
    async def test_search_passes_limit(self) -> None:
        client = self._qdrant_mock()
        with patch(self._CLIENT_PATH, return_value=client):
            await self._store().search("u1", b"query", top_k=7)
        # top_k is forwarded to Qdrant as `limit`.
        assert client.search.call_args.kwargs.get("limit") == 7

    @pytest.mark.asyncio
    async def test_search_returns_vector_search_results(self) -> None:
        client = self._qdrant_mock()
        with patch(self._CLIENT_PATH, return_value=client):
            found = await self._store().search("u1", b"query", top_k=5)
        assert len(found) == 1
        hit = found[0]
        assert isinstance(hit, VectorSearchResult)
        assert hit.id == "v1"
        assert hit.score == 0.88
        assert hit.blob == b"qdrant-result"

    @pytest.mark.asyncio
    async def test_delete_calls_client_delete(self) -> None:
        client = MagicMock()
        with patch(self._CLIENT_PATH, return_value=client):
            await self._store().delete("u1", ["v1", "v2"])
        client.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_uses_correct_collection(self) -> None:
        client = MagicMock()
        with patch(self._CLIENT_PATH, return_value=client):
            await self._store().delete("u1", ["v1"])
        assert client.delete.call_args.kwargs["collection_name"] == self._COLLECTION
|