"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""

from __future__ import annotations

import base64
import hashlib
from unittest.mock import MagicMock, patch

import boto3
import pytest
from botocore.exceptions import ClientError
from fastapi import HTTPException

from app.storage.encryption import reject_if_tampered, verify_checksum
from app.storage.blob_store import BlobStore
from app.storage.vector_store import VectorStore, _blob_to_vector
from app.schemas import VectorItem, VectorSearchResult
from tests.conftest import auth_header, S3_TEST_BUCKET


# ── Helpers ───────────────────────────────────────────────────────────

_BLOB = b"encrypted-payload-opaque-to-server"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_BUCKET = S3_TEST_BUCKET
_REGION = "us-east-1"


def _pinecone_mock() -> tuple[MagicMock, MagicMock]:
    """Return ``(mock_pinecone_class, mock_index)`` with realistic return shapes.

    ``mock_pinecone_class`` is meant to replace the ``Pinecone`` class: calling
    it and then ``.Index(...)`` yields ``mock_index``.  ``mock_index.query``
    returns a single match whose metadata carries a base64-encoded blob and
    its SHA-256 checksum.
    """
    mock_index = MagicMock()
    mock_index.query.return_value = {
        "matches": [
            {
                "id": "v1",
                "score": 0.95,
                "metadata": {
                    "blob": base64.b64encode(b"result-blob").decode(),
                    "checksum": hashlib.sha256(b"result-blob").hexdigest(),
                    "user_id": "u1",
                },
            }
        ]
    }
    mock_pc = MagicMock()
    mock_pc.return_value.Index.return_value = mock_index
    return mock_pc, mock_index


# ── TestEncryption ────────────────────────────────────────────────────


class TestEncryption:
    """Checksum verification and tamper rejection."""

    def test_verify_checksum_correct(self) -> None:
        assert verify_checksum(_BLOB, _CHECKSUM) is True

    def test_verify_checksum_wrong(self) -> None:
        assert verify_checksum(_BLOB, "0" * 64) is False

    def test_verify_checksum_empty_checksum(self) -> None:
        assert verify_checksum(_BLOB, "") is False

    def test_verify_checksum_empty_blob(self) -> None:
        expected = hashlib.sha256(b"").hexdigest()
        assert verify_checksum(b"", expected) is True

    def test_verify_checksum_tampered_blob(self) -> None:
        tampered = _BLOB + b"\x00"
        assert verify_checksum(tampered, _CHECKSUM) is False

    def test_reject_if_tampered_passes_when_valid(self) -> None:
        # Should not raise
        reject_if_tampered(_BLOB, _CHECKSUM)

    def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
        with pytest.raises(HTTPException) as exc_info:
            reject_if_tampered(_BLOB, "bad" * 20)
        assert exc_info.value.status_code == 400

    def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
        with pytest.raises(HTTPException) as exc_info:
            reject_if_tampered(_BLOB, "bad" * 20)
        assert "checksum" in exc_info.value.detail.lower()

    def test_checksum_is_sha256_hex(self) -> None:
        cs = hashlib.sha256(_BLOB).hexdigest()
        assert len(cs) == 64
        assert all(c in "0123456789abcdef" for c in cs)


# ── TestBlobStore ─────────────────────────────────────────────────────


class TestBlobStore:
    """BlobStore round-trips against the bucket provided by ``s3_bucket``."""

    @pytest.mark.asyncio
    async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
        store = BlobStore()
        key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        assert key == "u1/tasks/r1"

    @pytest.mark.asyncio
    async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        # Verify by downloading — no exception means object exists
        retrieved = await store.download("u1", "u1/tasks/r1")
        assert retrieved == _BLOB

    @pytest.mark.asyncio
    async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload(
            "u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest()
        )
        result = await store.download("u1", "u1/notes/n1")
        assert result == b"note-data"

    @pytest.mark.asyncio
    async def test_delete_removes_object(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.delete("u1", "u1/tasks/r1")
        with pytest.raises(ClientError) as exc_info:
            await store.download("u1", "u1/tasks/r1")
        assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"

    @pytest.mark.asyncio
    async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
        store = BlobStore()
        # Delete a key that never existed — should not raise
        await store.delete("u1", "u1/tasks/nonexistent")

    @pytest.mark.asyncio
    async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
        keys = await store.list_keys("u1", "tasks")
        assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}

    @pytest.mark.asyncio
    async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
        keys = await store.list_keys("u1", "tasks")
        assert "u1/notes/n1" not in keys
        assert "u1/tasks/r1" in keys

    @pytest.mark.asyncio
    async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
        keys_u1 = await store.list_keys("u1", "tasks")
        assert "u2/tasks/r1" not in keys_u1

    @pytest.mark.asyncio
    async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
        store = BlobStore()
        keys = await store.list_keys("u1", "tasks")
        assert keys == []

    @pytest.mark.asyncio
    async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        # Inspect the stored object with a plain boto3 client.  No settings
        # patch is needed: the upload has already happened, and this client
        # does not read app settings.
        client = boto3.client("s3", region_name=_REGION)
        response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
        assert response.get("ServerSideEncryption") == "AES256"

    @pytest.mark.asyncio
    async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        client = boto3.client("s3", region_name=_REGION)
        response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
        assert response["Metadata"]["checksum"] == _CHECKSUM


# ── _blob_to_vector helper ────────────────────────────────────────────


class TestBlobToVector:
    def test_returns_32_floats(self) -> None:
        v = _blob_to_vector(b"test")
        assert len(v) == 32

    def test_all_values_in_range(self) -> None:
        v = _blob_to_vector(b"test")
        assert all(-1.0 <= x <= 1.0 for x in v)

    def test_deterministic(self) -> None:
        assert _blob_to_vector(b"same") == _blob_to_vector(b"same")

    def test_different_blobs_different_vectors(self) -> None:
        assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")


# ── TestVectorStorePinecone ───────────────────────────────────────────


class TestVectorStorePinecone:
    """VectorStore with the Pinecone backend forced on and the client patched."""

    def _store(self) -> VectorStore:
        """Return a VectorStore whose backend selector always picks Pinecone."""
        store = VectorStore()
        store._use_pinecone = lambda: True  # type: ignore[method-assign]
        return store

    @pytest.mark.asyncio
    async def test_upsert_calls_index_upsert(self) -> None:
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            items = [
                VectorItem(
                    id="v1",
                    blob=b"enc-blob",
                    checksum=hashlib.sha256(b"enc-blob").hexdigest(),
                )
            ]
            await store.upsert("u1", items)
        mock_index.upsert.assert_called_once()
        call_kwargs = mock_index.upsert.call_args[1]
        assert call_kwargs.get("namespace") == "u1"

    @pytest.mark.asyncio
    async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            items = [
                VectorItem(
                    id="v1",
                    blob=b"secret",
                    checksum=hashlib.sha256(b"secret").hexdigest(),
                )
            ]
            await store.upsert("u1", items)
        vectors_arg = mock_index.upsert.call_args[1]["vectors"]
        assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()

    @pytest.mark.asyncio
    async def test_search_calls_index_query(self) -> None:
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.search("u1", b"query-blob", top_k=5)
        mock_index.query.assert_called_once()
        query_kwargs = mock_index.query.call_args[1]
        assert query_kwargs.get("namespace") == "u1"
        assert query_kwargs.get("top_k") == 5
        assert query_kwargs.get("include_metadata") is True

    @pytest.mark.asyncio
    async def test_search_returns_vector_search_results(self) -> None:
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            results = await store.search("u1", b"query", top_k=10)
        assert len(results) == 1
        assert isinstance(results[0], VectorSearchResult)
        assert results[0].id == "v1"
        assert results[0].score == 0.95
        assert results[0].blob == b"result-blob"

    @pytest.mark.asyncio
    async def test_search_uses_derived_query_vector(self) -> None:
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.search("u1", b"query-blob", top_k=3)
        expected_vector = _blob_to_vector(b"query-blob")
        actual_vector = mock_index.query.call_args[1].get("vector")
        assert actual_vector == expected_vector

    @pytest.mark.asyncio
    async def test_delete_calls_index_delete(self) -> None:
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.delete("u1", ["v1", "v2"])
        mock_index.delete.assert_called_once()
        delete_kwargs = mock_index.delete.call_args[1]
        assert delete_kwargs.get("namespace") == "u1"
        assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}


# ── TestVectorStoreQdrant ─────────────────────────────────────────────


class TestVectorStoreQdrant:
    """VectorStore with the Pinecone backend forced off (Qdrant path) and the client patched."""

    def _store(self) -> VectorStore:
        """Return a VectorStore whose backend selector never picks Pinecone."""
        store = VectorStore()
        store._use_pinecone = lambda: False  # type: ignore[method-assign]
        return store

    def _qdrant_mock(self) -> MagicMock:
        """Return a mock Qdrant client whose ``search`` yields one hit with a base64 blob payload."""
        mock_hit = MagicMock()
        mock_hit.id = "v1"
        mock_hit.score = 0.88
        mock_hit.payload = {
            "blob": base64.b64encode(b"qdrant-result").decode(),
            "user_id": "u1",
        }
        mock_client = MagicMock()
        mock_client.search.return_value = [mock_hit]
        return mock_client

    @pytest.mark.asyncio
    async def test_upsert_calls_client_upsert(self) -> None:
        mock_client = MagicMock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            items = [
                VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())
            ]
            await store.upsert("u1", items)
        mock_client.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_upsert_uses_correct_collection(self) -> None:
        mock_client = MagicMock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            items = [
                VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())
            ]
            await store.upsert("u1", items)
        call_kwargs = mock_client.upsert.call_args[1]
        assert call_kwargs["collection_name"] == "adiuva_vectors"

    @pytest.mark.asyncio
    async def test_search_calls_client_search(self) -> None:
        mock_client = self._qdrant_mock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            await store.search("u1", b"query", top_k=5)
        mock_client.search.assert_called_once()

    @pytest.mark.asyncio
    async def test_search_passes_limit(self) -> None:
        mock_client = self._qdrant_mock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            await store.search("u1", b"query", top_k=7)
        call_kwargs = mock_client.search.call_args[1]
        assert call_kwargs.get("limit") == 7

    @pytest.mark.asyncio
    async def test_search_returns_vector_search_results(self) -> None:
        mock_client = self._qdrant_mock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            results = await store.search("u1", b"query", top_k=5)
        assert len(results) == 1
        assert isinstance(results[0], VectorSearchResult)
        assert results[0].id == "v1"
        assert results[0].score == 0.88
        assert results[0].blob == b"qdrant-result"

    @pytest.mark.asyncio
    async def test_delete_calls_client_delete(self) -> None:
        mock_client = MagicMock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            await store.delete("u1", ["v1", "v2"])
        mock_client.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_uses_correct_collection(self) -> None:
        mock_client = MagicMock()
        with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
            store = self._store()
            await store.delete("u1", ["v1"])
        call_kwargs = mock_client.delete.call_args[1]
        assert call_kwargs["collection_name"] == "adiuva_vectors"


# ── TestStorageRoutes (integration) ───────────────────────────────────


class TestStorageRoutes:
    """Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.

    Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
    So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
    ASCII strings as blob values and compute checksums accordingly.
    """

    _BLOB_STR = "encrypted-payload-opaque-to-server"
    _BLOB_BYTES = _BLOB_STR.encode()
    _BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()

    @classmethod
    def _create_payload(
        cls, blob_str: str | None = None, table: str = "tasks"
    ) -> dict[str, str]:
        """Build a create-record JSON body whose checksum matches *blob_str*.

        ``None`` selects the class default blob.  An explicit empty string is
        kept as-is; the previous ``or`` fallback silently replaced it.
        """
        if blob_str is None:
            blob_str = cls._BLOB_STR
        return {
            "table": table,
            "blob": blob_str,
            "checksum": hashlib.sha256(blob_str.encode()).hexdigest(),
        }

    def _create_record(
        self,
        client,
        tier: str = "power",
        blob_str: str | None = None,
        table: str = "tasks",
    ):
        """POST a new record using *tier*'s auth header and return the raw response."""
        return client.post(
            "/api/v1/storage/records",
            json=self._create_payload(blob_str, table),
            headers=auth_header(tier),
        )

    def _create_record_id(self, client, **kwargs) -> str:
        """Create a record and return its id, failing fast if creation did not return 201."""
        resp = self._create_record(client, **kwargs)
        assert resp.status_code == 201, resp.text
        return resp.json()["id"]

    # ── Create ────────────────────────────────────────────────────────

    def test_create_record(self, client, s3_bucket) -> None:
        resp = self._create_record(client)
        assert resp.status_code == 201
        data = resp.json()
        assert "id" in data
        assert "created_at" in data

    def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
        payload = {
            "table": "tasks",
            "blob": self._BLOB_STR,
            "checksum": "0" * 64,
        }
        resp = client.post(
            "/api/v1/storage/records",
            json=payload,
            headers=auth_header("power"),
        )
        assert resp.status_code == 400

    def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
        """Free tier has cloud_storage_gb=0 → 402."""
        resp = self._create_record(client, tier="free")
        assert resp.status_code == 402

    def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
        """Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
        resp = self._create_record(client, tier="pro")
        assert resp.status_code == 201

    # ── List ──────────────────────────────────────────────────────────

    def test_list_records(self, client, s3_bucket) -> None:
        self._create_record_id(client)
        self._create_record_id(client, blob_str="second-blob")
        resp = client.get(
            "/api/v1/storage/records",
            headers=auth_header("power"),
        )
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 2
        # Each entry has metadata, no blob bytes
        for item in data:
            assert "id" in item
            assert "table" in item
            assert "checksum" in item
            assert "blob" not in item

    def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
        self._create_record_id(client)
        # Create in a different table
        self._create_record_id(client, blob_str="note-blob", table="notes")
        resp = client.get(
            "/api/v1/storage/records?table=notes",
            headers=auth_header("power"),
        )
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 1
        assert data[0]["table"] == "notes"

    def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
        """One user's records should not appear in another user's list."""
        # NOTE(review): assumes auth_header() maps each tier to a distinct
        # test user — confirm against tests/conftest.py.
        self._create_record_id(client, tier="power")
        resp = client.get(
            "/api/v1/storage/records",
            headers=auth_header("team"),
        )
        assert resp.status_code == 200
        assert resp.json() == []

    # ── Download ──────────────────────────────────────────────────────

    def test_download_record(self, client, s3_bucket) -> None:
        record_id = self._create_record_id(client)
        resp = client.get(
            f"/api/v1/storage/records/{record_id}",
            headers=auth_header("power"),
        )
        assert resp.status_code == 200
        assert resp.content == self._BLOB_BYTES
        assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM

    def test_download_record_not_found(self, client, s3_bucket) -> None:
        resp = client.get(
            "/api/v1/storage/records/nonexistent-id",
            headers=auth_header("power"),
        )
        assert resp.status_code == 404

    # ── Update ────────────────────────────────────────────────────────

    def test_update_record(self, client, s3_bucket) -> None:
        record_id = self._create_record_id(client)
        new_blob_str = "updated-encrypted-payload"
        new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
        resp = client.put(
            f"/api/v1/storage/records/{record_id}",
            json={"blob": new_blob_str, "checksum": new_checksum},
            headers=auth_header("power"),
        )
        assert resp.status_code == 200
        assert resp.json() == {"ok": True}
        # Verify download returns the updated blob
        dl = client.get(
            f"/api/v1/storage/records/{record_id}",
            headers=auth_header("power"),
        )
        assert dl.status_code == 200
        assert dl.content == new_blob_str.encode()

    def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
        record_id = self._create_record_id(client)
        resp = client.put(
            f"/api/v1/storage/records/{record_id}",
            json={"blob": "some-data", "checksum": "0" * 64},
            headers=auth_header("power"),
        )
        assert resp.status_code == 400

    # ── Delete ────────────────────────────────────────────────────────

    def test_delete_record(self, client, s3_bucket) -> None:
        record_id = self._create_record_id(client)
        resp = client.delete(
            f"/api/v1/storage/records/{record_id}",
            headers=auth_header("power"),
        )
        assert resp.status_code == 200
        assert resp.json() == {"ok": True}
        # Subsequent GET should return 404
        dl = client.get(
            f"/api/v1/storage/records/{record_id}",
            headers=auth_header("power"),
        )
        assert dl.status_code == 404

    def test_delete_record_not_found(self, client, s3_bucket) -> None:
        resp = client.delete(
            "/api/v1/storage/records/nonexistent",
            headers=auth_header("power"),
        )
        assert resp.status_code == 404