refactor: remove storage, backup, plugin/marketplace features

- Delete app/storage/ (blob_store, vector_store, encryption)
- Delete app/marketplace/ (plugin_registry, plugin_review, revenue_share)
- Delete routes: backup.py, plugins.py, storage.py, vectors.py
- Relocate embed endpoint to POST /chat/embed
- Rewrite migration 001 (remove storage/plugin tables)
- Delete migration 002 (seed_plugins)
- Remove S3/Pinecone/Qdrant env vars from settings
- Remove storage/backup quotas from tier_manager
- Remove MinIO and Qdrant from docker-compose
- Delete tests: test_backup, test_plugins, test_storage
- Update README.md and clean .env.example
This commit is contained in:
Roberto Musso
2026-04-08 00:47:37 +02:00
parent aa8bcbf0d8
commit 5753f8def9
28 changed files with 89 additions and 3570 deletions

View File

@@ -1,243 +0,0 @@
"""Tests for backup routes: upload, download, history, delete.
Exercises the backup lifecycle through the FastAPI TestClient against the
in-memory SQLite test database and moto-mocked S3 bucket.
"""
from __future__ import annotations
import hashlib
from tests.conftest import auth_header, TEST_USER_IDS
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-backup-blob-opaque-bytes"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_VERSION = 1
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
"""Return auth + backup metadata headers."""
headers = auth_header(tier)
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
headers["Content-Type"] = "application/octet-stream"
return headers
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
"""Upload a backup blob and return the response."""
return client.put(
"/api/v1/backup",
content=overrides.pop("blob", _BLOB),
headers=_backup_headers(tier, **overrides),
)
# ── TestUploadBackup ──────────────────────────────────────────────────
class TestUploadBackup:
"""PUT /api/v1/backup"""
def test_upload_success(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power")
assert resp.status_code == 200
assert resp.json() == {"ok": True}
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
_upload(client, tier="power")
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 1
assert history[0]["version"] == _VERSION
assert history[0]["timestamp"] == _TIMESTAMP
assert history[0]["checksum"] == _CHECKSUM
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power", checksum="0" * 64)
assert resp.status_code == 400
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has backup_gb=0 → should return 402."""
resp = _upload(client, tier="free")
assert resp.status_code == 402
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has backup_gb=5 → small blob succeeds."""
resp = _upload(client, tier="pro")
assert resp.status_code == 200
# ── TestDownloadBackup ────────────────────────────────────────────────
class TestDownloadBackup:
"""GET /api/v1/backup"""
def test_download_latest(self, client, s3_bucket) -> None:
_upload(client, tier="power")
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == _BLOB
assert resp.headers["X-Checksum"] == _CHECKSUM
assert resp.headers["X-Backup-Version"] == str(_VERSION)
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 404
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
"""When If-Modified-Since is after the backup timestamp → 304."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
},
)
assert resp.status_code == 304
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
"""When If-Modified-Since is before the backup timestamp → serve blob."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
},
)
assert resp.status_code == 200
assert resp.content == _BLOB
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
"""When multiple backups exist, GET returns the one with the highest timestamp."""
_upload(client, tier="power", timestamp=1000)
blob2 = b"second-encrypted-backup"
checksum2 = hashlib.sha256(blob2).hexdigest()
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == blob2
# ── TestBackupHistory ─────────────────────────────────────────────────
class TestBackupHistory:
"""GET /api/v1/backup/history"""
def test_history_empty(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.json() == []
def test_history_returns_entries(self, client, s3_bucket) -> None:
_upload(client, tier="power", timestamp=1000)
_upload(client, tier="power", timestamp=2000)
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 2
# Ordered by timestamp descending
assert history[0]["timestamp"] == 2000
assert history[1]["timestamp"] == 1000
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's backups should not appear in another user's history."""
_upload(client, tier="power")
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
assert resp.json() == []
# ── TestDeleteBackup ──────────────────────────────────────────────────
class TestDeleteBackup:
"""DELETE /api/v1/backup/{backup_id}"""
def _get_backup_id(self, client, tier="power") -> str:
"""Upload a backup and return its DB id from history."""
_upload(client, tier=tier)
client.get(
"/api/v1/backup/history", headers=auth_header(tier)
).json()
# History returns BackupMetadata schema which doesn't have `id`.
# We need to look it up via a different means.
# Since there's only 1 backup, find via history length.
# Actually the schema doesn't return id — let's verify via re-download.
# We'll use a workaround: upload, then list history to confirm it exists,
# then try to delete — but we need the id...
# Let's check if history includes an id field.
# The schema is: version, timestamp, checksum, chunk_count — no id.
# We'll need to query the DB directly or use a known ID.
# For testing, we'll search history then use the DB.
return None # pragma: no cover — overridden below
def test_delete_success(self, client, s3_bucket, db_session) -> None:
_upload(client, tier="power")
# Discover the backup_id via direct DB query
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# History should now be empty
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert history == []
def test_delete_nonexistent(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/backup/no-such-id", headers=auth_header("power")
)
assert resp.status_code == 404
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
"""Cannot delete another user's backup (ownership check returns 404)."""
_upload(client, tier="power")
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
# team user tries to delete power user's backup → 404
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
)
assert resp.status_code == 404

View File

@@ -1,400 +0,0 @@
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
Covers:
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
- ReviewQueue: pending queue, review decisions, manifest security checklist
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
- Route integration: tier gate, list/get/install/uninstall via TestClient
"""
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import PluginRegistry
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
from app.marketplace.revenue_share import RevenueShare
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
from app.schemas import PluginManifest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fresh_manifest(
plugin_id: str | None = None,
category: str = "productivity",
price_cents: int = 0,
permissions: list[str] | None = None,
) -> PluginManifest:
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
return PluginManifest(
id=pid,
name=f"Plugin {pid}",
description=f"Description for {pid}",
version="1.0.0",
author="test-author",
permissions=permissions or ["read:tasks"],
category=category,
price_cents=price_cents,
)
# ---------------------------------------------------------------------------
# PluginRegistry (DB-backed)
# ---------------------------------------------------------------------------
class TestPluginRegistry:
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.mark.asyncio
async def test_seed_plugins_are_listed(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session)
assert result.total == 3
assert all(p.id.startswith("plugin-") for p in result.plugins)
@pytest.mark.asyncio
async def test_list_approved_only(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
result = await reg.list_plugins(db_session)
ids = [p.id for p in result.plugins]
assert manifest.id not in ids # still pending
@pytest.mark.asyncio
async def test_list_filter_by_category(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, category="communication")
assert result.total == 1
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"
@pytest.mark.asyncio
async def test_list_sort_by_installs(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-slack-notify")
await reg.record_install(db_session, "plugin-slack-notify")
result = await reg.list_plugins(db_session, sort="installs")
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_get_plugin_found(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["manifest"].id == "plugin-github-sync"
assert "install_count" in entry
@pytest.mark.asyncio
async def test_get_plugin_not_found(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
entry = await reg.get_plugin(db_session, "no-such-plugin")
assert entry is None
@pytest.mark.asyncio
async def test_submit_sets_pending(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
assert plugin_id == manifest.id
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one()
assert row.status == "pending_review"
@pytest.mark.asyncio
async def test_approve_makes_visible(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.approve_plugin(db_session, manifest.id)
result = await reg.list_plugins(db_session)
assert manifest.id in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_reject_stores_reason(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
assert row.rejection_reason == "Unsafe permissions"
listed = await reg.list_plugins(db_session)
assert manifest.id not in [p.id for p in listed.plugins]
@pytest.mark.asyncio
async def test_approve_unknown_raises_key_error(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
with pytest.raises(KeyError):
await reg.approve_plugin(db_session, "ghost-plugin")
@pytest.mark.asyncio
async def test_record_install_increments_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_decrements_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_floors_at_zero(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 0
# ---------------------------------------------------------------------------
# ReviewQueue (DB-backed)
# ---------------------------------------------------------------------------
class TestReviewQueue:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def queue(self) -> ReviewQueue:
return ReviewQueue()
@pytest.mark.asyncio
async def test_get_pending_returns_submitted_plugins(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
pending = await queue.get_pending(db_session)
assert any(p["plugin_id"] == manifest.id for p in pending)
@pytest.mark.asyncio
async def test_submit_review_approved(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "approved"
# Check review row was persisted
review_result = await db_session.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
)
review = review_result.scalar_one()
assert review.decision == "approved"
@pytest.mark.asyncio
async def test_submit_review_rejected(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
)
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
def test_validate_manifest_ok(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
validate_manifest(manifest) # should not raise
def test_validate_manifest_unknown_permission(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
with pytest.raises(ValueError, match="Unknown permission"):
validate_manifest(manifest)
def test_validate_manifest_invalid_id_format(self) -> None:
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
def test_validate_manifest_id_with_uppercase(self) -> None:
manifest = _fresh_manifest(plugin_id="UpperCase")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
# ---------------------------------------------------------------------------
# RevenueShare (DB-backed)
# ---------------------------------------------------------------------------
class TestRevenueShare:
@pytest.fixture
def rs(self) -> RevenueShare:
return RevenueShare()
@pytest.mark.asyncio
async def test_record_install_free_plugin(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
)
event = result.scalar_one()
assert event.developer_share_cents == 0
@pytest.mark.asyncio
async def test_record_install_paid_plugin_no_stripe(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
)
event = result.scalar_one()
assert event.amount_cents == 499
assert event.developer_share_cents == int(499 * 0.70)
@pytest.mark.asyncio
async def test_record_install_increments_registry_count(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
reg = PluginRegistry()
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_get_earnings_empty(
self, rs: RevenueShare, db_session: AsyncSession
) -> None:
result = await rs.get_earnings(db_session, "unknown-dev")
assert result["total_installs"] == 0
assert result["total_revenue_cents"] == 0
assert result["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_get_earnings_aggregates(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
result = await rs.get_earnings(db_session, "Adiuva")
assert result["total_installs"] == 2
assert result["total_revenue_cents"] == 998
assert result["developer_share_cents"] == int(499 * 0.70) * 2
# ---------------------------------------------------------------------------
# Route integration tests
# ---------------------------------------------------------------------------
class TestPluginRoutes:
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
assert resp.status_code == 403
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
assert resp.status_code == 403
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
assert resp.status_code == 200
data = resp.json()
assert "plugins" in data
assert data["total"] == 3
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
assert resp.status_code == 200
def test_get_plugin_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
assert resp.status_code == 200
data = resp.json()
assert data["plugin"]["id"] == "plugin-github-sync"
assert "install_count" in data
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
assert resp.status_code == 404
def test_install_plugin_free(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header(),
)
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert "download_url" in data
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=auth_header(),
)
assert resp.status_code == 404
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=auth_header(),
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header("free"),
)
assert resp.status_code == 403

View File

@@ -1,562 +0,0 @@
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
from __future__ import annotations
import base64
import hashlib
from unittest.mock import MagicMock, patch
import boto3
import pytest
from botocore.exceptions import ClientError
from app.storage.encryption import reject_if_tampered, verify_checksum
from app.storage.blob_store import BlobStore
from app.storage.vector_store import VectorStore, _blob_to_vector
from app.schemas import VectorItem, VectorSearchResult
from tests.conftest import auth_header, S3_TEST_BUCKET
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-payload-opaque-to-server"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_BUCKET = S3_TEST_BUCKET
_REGION = "us-east-1"
def _pinecone_mock():
"""Return a mock Pinecone index with realistic return shapes."""
mock_index = MagicMock()
mock_index.query.return_value = {
"matches": [
{
"id": "v1",
"score": 0.95,
"metadata": {
"blob": base64.b64encode(b"result-blob").decode(),
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
"user_id": "u1",
},
}
]
}
mock_pc = MagicMock()
mock_pc.return_value.Index.return_value = mock_index
return mock_pc, mock_index
# ── TestEncryption ────────────────────────────────────────────────────
class TestEncryption:
def test_verify_checksum_correct(self) -> None:
assert verify_checksum(_BLOB, _CHECKSUM) is True
def test_verify_checksum_wrong(self) -> None:
assert verify_checksum(_BLOB, "0" * 64) is False
def test_verify_checksum_empty_checksum(self) -> None:
assert verify_checksum(_BLOB, "") is False
def test_verify_checksum_empty_blob(self) -> None:
expected = hashlib.sha256(b"").hexdigest()
assert verify_checksum(b"", expected) is True
def test_verify_checksum_tampered_blob(self) -> None:
tampered = _BLOB + b"\x00"
assert verify_checksum(tampered, _CHECKSUM) is False
def test_reject_if_tampered_passes_when_valid(self) -> None:
# Should not raise
reject_if_tampered(_BLOB, _CHECKSUM)
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert exc_info.value.status_code == 400
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert "checksum" in exc_info.value.detail.lower()
def test_checksum_is_sha256_hex(self) -> None:
cs = hashlib.sha256(_BLOB).hexdigest()
assert len(cs) == 64
assert all(c in "0123456789abcdef" for c in cs)
# ── TestBlobStore ─────────────────────────────────────────────────────
class TestBlobStore:
@pytest.mark.asyncio
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
store = BlobStore()
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
assert key == "u1/tasks/r1"
@pytest.mark.asyncio
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify by downloading — no exception means object exists
retrieved = await store.download("u1", "u1/tasks/r1")
assert retrieved == _BLOB
@pytest.mark.asyncio
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
result = await store.download("u1", "u1/notes/n1")
assert result == b"note-data"
@pytest.mark.asyncio
async def test_delete_removes_object(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.delete("u1", "u1/tasks/r1")
with pytest.raises(ClientError) as exc_info:
await store.download("u1", "u1/tasks/r1")
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
@pytest.mark.asyncio
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
store = BlobStore()
# Delete a key that never existed — should not raise
await store.delete("u1", "u1/tasks/nonexistent")
@pytest.mark.asyncio
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
@pytest.mark.asyncio
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert "u1/notes/n1" not in keys
assert "u1/tasks/r1" in keys
@pytest.mark.asyncio
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
keys_u1 = await store.list_keys("u1", "tasks")
assert "u2/tasks/r1" not in keys_u1
@pytest.mark.asyncio
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
store = BlobStore()
keys = await store.list_keys("u1", "tasks")
assert keys == []
@pytest.mark.asyncio
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify S3 metadata was set — check via head_object
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = _BUCKET
mock_settings.S3_REGION = _REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response.get("ServerSideEncryption") == "AES256"
@pytest.mark.asyncio
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response["Metadata"]["checksum"] == _CHECKSUM
# ── _blob_to_vector helper ────────────────────────────────────────────
class TestBlobToVector:
def test_returns_32_floats(self) -> None:
v = _blob_to_vector(b"test")
assert len(v) == 32
def test_all_values_in_range(self) -> None:
v = _blob_to_vector(b"test")
assert all(-1.0 <= x <= 1.0 for x in v)
def test_deterministic(self) -> None:
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
def test_different_blobs_different_vectors(self) -> None:
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
# ── TestVectorStorePinecone ───────────────────────────────────────────
class TestVectorStorePinecone:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: True # type: ignore[method-assign]
return store
@pytest.mark.asyncio
async def test_upsert_calls_index_upsert(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
await store.upsert("u1", items)
mock_index.upsert.assert_called_once()
call_kwargs = mock_index.upsert.call_args[1]
assert call_kwargs.get("namespace") == "u1"
@pytest.mark.asyncio
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
await store.upsert("u1", items)
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
@pytest.mark.asyncio
async def test_search_calls_index_query(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=5)
mock_index.query.assert_called_once()
query_kwargs = mock_index.query.call_args[1]
assert query_kwargs.get("namespace") == "u1"
assert query_kwargs.get("top_k") == 5
assert query_kwargs.get("include_metadata") is True
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
results = await store.search("u1", b"query", top_k=10)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.95
assert results[0].blob == b"result-blob"
@pytest.mark.asyncio
async def test_search_uses_derived_query_vector(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=3)
expected_vector = _blob_to_vector(b"query-blob")
actual_vector = mock_index.query.call_args[1].get("vector")
assert actual_vector == expected_vector
@pytest.mark.asyncio
async def test_delete_calls_index_delete(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_index.delete.assert_called_once()
delete_kwargs = mock_index.delete.call_args[1]
assert delete_kwargs.get("namespace") == "u1"
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
class TestVectorStoreQdrant:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: False # type: ignore[method-assign]
return store
def _qdrant_mock(self) -> MagicMock:
mock_hit = MagicMock()
mock_hit.id = "v1"
mock_hit.score = 0.88
mock_hit.payload = {
"blob": base64.b64encode(b"qdrant-result").decode(),
"user_id": "u1",
}
mock_client = MagicMock()
mock_client.search.return_value = [mock_hit]
return mock_client
@pytest.mark.asyncio
async def test_upsert_calls_client_upsert(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
mock_client.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_upsert_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
call_kwargs = mock_client.upsert.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
@pytest.mark.asyncio
async def test_search_calls_client_search(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=5)
mock_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_search_passes_limit(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=7)
call_kwargs = mock_client.search.call_args[1]
assert call_kwargs.get("limit") == 7
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
results = await store.search("u1", b"query", top_k=5)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.88
assert results[0].blob == b"qdrant-result"
@pytest.mark.asyncio
async def test_delete_calls_client_delete(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_client.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1"])
call_kwargs = mock_client.delete.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
# ── TestStorageRoutes (integration) ───────────────────────────────────
class TestStorageRoutes:
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
ASCII strings as blob values and compute checksums accordingly.
"""
_BLOB_STR = "encrypted-payload-opaque-to-server"
_BLOB_BYTES = _BLOB_STR.encode()
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
@classmethod
def _create_payload(cls, blob_str: str | None = None) -> dict:
blob_str = blob_str or cls._BLOB_STR
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
return {
"table": "tasks",
"blob": blob_str,
"checksum": checksum,
}
def _create_record(self, client, tier="power", blob_str=None):
payload = self._create_payload(blob_str)
return client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header(tier),
)
# ── Create ────────────────────────────────────────────────────────
def test_create_record(self, client, s3_bucket) -> None:
resp = self._create_record(client)
assert resp.status_code == 201
data = resp.json()
assert "id" in data
assert "created_at" in data
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
payload = {
"table": "tasks",
"blob": self._BLOB_STR,
"checksum": "0" * 64,
}
resp = client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header("power"),
)
assert resp.status_code == 400
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has cloud_storage_gb=0 → 402."""
resp = self._create_record(client, tier="free")
assert resp.status_code == 402
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
resp = self._create_record(client, tier="pro")
assert resp.status_code == 201
# ── List ──────────────────────────────────────────────────────────
def test_list_records(self, client, s3_bucket) -> None:
self._create_record(client)
self._create_record(client, blob_str="second-blob")
resp = client.get(
"/api/v1/storage/records",
headers=auth_header("power"),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
# Each entry has metadata, no blob bytes
for item in data:
assert "id" in item
assert "table" in item
assert "checksum" in item
assert "blob" not in item
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
self._create_record(client)
# Create in a different table
note_blob = "note-blob"
payload = {
"table": "notes",
"blob": note_blob,
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
}
client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header("power"),
)
resp = client.get(
"/api/v1/storage/records?table=notes",
headers=auth_header("power"),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["table"] == "notes"
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's records should not appear in another user's list."""
self._create_record(client, tier="power")
resp = client.get(
"/api/v1/storage/records",
headers=auth_header("team"),
)
assert resp.json() == []
# ── Download ──────────────────────────────────────────────────────
def test_download_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.content == self._BLOB_BYTES
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
def test_download_record_not_found(self, client, s3_bucket) -> None:
resp = client.get(
"/api/v1/storage/records/nonexistent-id",
headers=auth_header("power"),
)
assert resp.status_code == 404
# ── Update ────────────────────────────────────────────────────────
def test_update_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
new_blob_str = "updated-encrypted-payload"
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
resp = client.put(
f"/api/v1/storage/records/{record_id}",
json={"blob": new_blob_str, "checksum": new_checksum},
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# Verify download returns the updated blob
dl = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert dl.content == new_blob_str.encode()
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.put(
f"/api/v1/storage/records/{record_id}",
json={"blob": "some-data", "checksum": "0" * 64},
headers=auth_header("power"),
)
assert resp.status_code == 400
# ── Delete ────────────────────────────────────────────────────────
def test_delete_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.delete(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# Subsequent GET should return 404
dl = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert dl.status_code == 404
def test_delete_record_not_found(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/storage/records/nonexistent",
headers=auth_header("power"),
)
assert resp.status_code == 404