refactor: remove storage, backup, plugin/marketplace features
- Delete app/storage/ (blob_store, vector_store, encryption) - Delete app/marketplace/ (plugin_registry, plugin_review, revenue_share) - Delete routes: backup.py, plugins.py, storage.py, vectors.py - Relocate embed endpoint to POST /chat/embed - Rewrite migration 001 (remove storage/plugin tables) - Delete migration 002 (seed_plugins) - Remove S3/Pinecone/Qdrant env vars from settings - Remove storage/backup quotas from tier_manager - Remove MinIO and Qdrant from docker-compose - Delete tests: test_backup, test_plugins, test_storage - Update README.md and clean .env.example
This commit is contained in:
@@ -1,243 +0,0 @@
|
||||
"""Tests for backup routes: upload, download, history, delete.
|
||||
|
||||
Exercises the backup lifecycle through the FastAPI TestClient against the
|
||||
in-memory SQLite test database and moto-mocked S3 bucket.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
|
||||
|
||||
from tests.conftest import auth_header, TEST_USER_IDS
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
_BLOB = b"encrypted-backup-blob-opaque-bytes"
|
||||
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||
_VERSION = 1
|
||||
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
|
||||
|
||||
|
||||
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
|
||||
"""Return auth + backup metadata headers."""
|
||||
headers = auth_header(tier)
|
||||
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
|
||||
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
|
||||
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
|
||||
headers["Content-Type"] = "application/octet-stream"
|
||||
return headers
|
||||
|
||||
|
||||
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
|
||||
"""Upload a backup blob and return the response."""
|
||||
return client.put(
|
||||
"/api/v1/backup",
|
||||
content=overrides.pop("blob", _BLOB),
|
||||
headers=_backup_headers(tier, **overrides),
|
||||
)
|
||||
|
||||
|
||||
# ── TestUploadBackup ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestUploadBackup:
|
||||
"""PUT /api/v1/backup"""
|
||||
|
||||
def test_upload_success(self, client, s3_bucket) -> None:
|
||||
resp = _upload(client, tier="power")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"ok": True}
|
||||
|
||||
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
|
||||
_upload(client, tier="power")
|
||||
history = client.get(
|
||||
"/api/v1/backup/history", headers=auth_header("power")
|
||||
).json()
|
||||
assert len(history) == 1
|
||||
assert history[0]["version"] == _VERSION
|
||||
assert history[0]["timestamp"] == _TIMESTAMP
|
||||
assert history[0]["checksum"] == _CHECKSUM
|
||||
|
||||
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
|
||||
resp = _upload(client, tier="power", checksum="0" * 64)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||
"""Free tier has backup_gb=0 → should return 402."""
|
||||
resp = _upload(client, tier="free")
|
||||
assert resp.status_code == 402
|
||||
|
||||
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||
"""Pro tier has backup_gb=5 → small blob succeeds."""
|
||||
resp = _upload(client, tier="pro")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ── TestDownloadBackup ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDownloadBackup:
|
||||
"""GET /api/v1/backup"""
|
||||
|
||||
def test_download_latest(self, client, s3_bucket) -> None:
|
||||
_upload(client, tier="power")
|
||||
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == _BLOB
|
||||
assert resp.headers["X-Checksum"] == _CHECKSUM
|
||||
assert resp.headers["X-Backup-Version"] == str(_VERSION)
|
||||
|
||||
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
|
||||
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
|
||||
"""When If-Modified-Since is after the backup timestamp → 304."""
|
||||
_upload(client, tier="power", timestamp=1700000000000)
|
||||
resp = client.get(
|
||||
"/api/v1/backup",
|
||||
headers={
|
||||
**auth_header("power"),
|
||||
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 304
|
||||
|
||||
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
|
||||
"""When If-Modified-Since is before the backup timestamp → serve blob."""
|
||||
_upload(client, tier="power", timestamp=1700000000000)
|
||||
resp = client.get(
|
||||
"/api/v1/backup",
|
||||
headers={
|
||||
**auth_header("power"),
|
||||
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == _BLOB
|
||||
|
||||
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
|
||||
"""When multiple backups exist, GET returns the one with the highest timestamp."""
|
||||
_upload(client, tier="power", timestamp=1000)
|
||||
blob2 = b"second-encrypted-backup"
|
||||
checksum2 = hashlib.sha256(blob2).hexdigest()
|
||||
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
|
||||
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == blob2
|
||||
|
||||
|
||||
# ── TestBackupHistory ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBackupHistory:
|
||||
"""GET /api/v1/backup/history"""
|
||||
|
||||
def test_history_empty(self, client, s3_bucket) -> None:
|
||||
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
def test_history_returns_entries(self, client, s3_bucket) -> None:
|
||||
_upload(client, tier="power", timestamp=1000)
|
||||
_upload(client, tier="power", timestamp=2000)
|
||||
history = client.get(
|
||||
"/api/v1/backup/history", headers=auth_header("power")
|
||||
).json()
|
||||
assert len(history) == 2
|
||||
# Ordered by timestamp descending
|
||||
assert history[0]["timestamp"] == 2000
|
||||
assert history[1]["timestamp"] == 1000
|
||||
|
||||
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
|
||||
"""One user's backups should not appear in another user's history."""
|
||||
_upload(client, tier="power")
|
||||
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
# ── TestDeleteBackup ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDeleteBackup:
|
||||
"""DELETE /api/v1/backup/{backup_id}"""
|
||||
|
||||
def _get_backup_id(self, client, tier="power") -> str:
|
||||
"""Upload a backup and return its DB id from history."""
|
||||
_upload(client, tier=tier)
|
||||
client.get(
|
||||
"/api/v1/backup/history", headers=auth_header(tier)
|
||||
).json()
|
||||
# History returns BackupMetadata schema which doesn't have `id`.
|
||||
# We need to look it up via a different means.
|
||||
# Since there's only 1 backup, find via history length.
|
||||
# Actually the schema doesn't return id — let's verify via re-download.
|
||||
# We'll use a workaround: upload, then list history to confirm it exists,
|
||||
# then try to delete — but we need the id...
|
||||
# Let's check if history includes an id field.
|
||||
# The schema is: version, timestamp, checksum, chunk_count — no id.
|
||||
# We'll need to query the DB directly or use a known ID.
|
||||
# For testing, we'll search history then use the DB.
|
||||
return None # pragma: no cover — overridden below
|
||||
|
||||
def test_delete_success(self, client, s3_bucket, db_session) -> None:
|
||||
_upload(client, tier="power")
|
||||
|
||||
# Discover the backup_id via direct DB query
|
||||
import asyncio
|
||||
from sqlalchemy import select
|
||||
from app.models import BackupMetadata
|
||||
|
||||
async def _get_id():
|
||||
result = await db_session.execute(
|
||||
select(BackupMetadata.id).where(
|
||||
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||
|
||||
resp = client.delete(
|
||||
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"ok": True}
|
||||
|
||||
# History should now be empty
|
||||
history = client.get(
|
||||
"/api/v1/backup/history", headers=auth_header("power")
|
||||
).json()
|
||||
assert history == []
|
||||
|
||||
def test_delete_nonexistent(self, client, s3_bucket) -> None:
|
||||
resp = client.delete(
|
||||
"/api/v1/backup/no-such-id", headers=auth_header("power")
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
|
||||
"""Cannot delete another user's backup (ownership check returns 404)."""
|
||||
_upload(client, tier="power")
|
||||
|
||||
import asyncio
|
||||
from sqlalchemy import select
|
||||
from app.models import BackupMetadata
|
||||
|
||||
async def _get_id():
|
||||
result = await db_session.execute(
|
||||
select(BackupMetadata.id).where(
|
||||
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
|
||||
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||
|
||||
# team user tries to delete power user's backup → 404
|
||||
resp = client.delete(
|
||||
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
@@ -1,400 +0,0 @@
|
||||
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
|
||||
|
||||
Covers:
|
||||
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
|
||||
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
|
||||
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.marketplace.plugin_registry import PluginRegistry
|
||||
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||
from app.marketplace.revenue_share import RevenueShare
|
||||
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
|
||||
from app.schemas import PluginManifest
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fresh_manifest(
|
||||
plugin_id: str | None = None,
|
||||
category: str = "productivity",
|
||||
price_cents: int = 0,
|
||||
permissions: list[str] | None = None,
|
||||
) -> PluginManifest:
|
||||
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
|
||||
return PluginManifest(
|
||||
id=pid,
|
||||
name=f"Plugin {pid}",
|
||||
description=f"Description for {pid}",
|
||||
version="1.0.0",
|
||||
author="test-author",
|
||||
permissions=permissions or ["read:tasks"],
|
||||
category=category,
|
||||
price_cents=price_cents,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginRegistry (DB-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginRegistry:
|
||||
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
|
||||
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_seed_plugins_are_listed(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session)
|
||||
assert result.total == 3
|
||||
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_approved_only(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
|
||||
result = await reg.list_plugins(db_session)
|
||||
ids = [p.id for p in result.plugins]
|
||||
assert manifest.id not in ids # still pending
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_filter_by_category(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session, category="communication")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-slack-notify"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_filter_by_query(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session, query="time")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-time-tracker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sort_by_installs(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_install(db_session, "plugin-slack-notify")
|
||||
await reg.record_install(db_session, "plugin-slack-notify")
|
||||
result = await reg.list_plugins(db_session, sort="installs")
|
||||
assert result.plugins[0].id == "plugin-slack-notify"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_found(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["manifest"].id == "plugin-github-sync"
|
||||
assert "install_count" in entry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_not_found(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
entry = await reg.get_plugin(db_session, "no-such-plugin")
|
||||
assert entry is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_sets_pending(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
assert plugin_id == manifest.id
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "pending_review"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_makes_visible(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await reg.approve_plugin(db_session, manifest.id)
|
||||
result = await reg.list_plugins(db_session)
|
||||
assert manifest.id in [p.id for p in result.plugins]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_stores_reason(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "rejected"
|
||||
assert row.rejection_reason == "Unsafe permissions"
|
||||
listed = await reg.list_plugins(db_session)
|
||||
assert manifest.id not in [p.id for p in listed.plugins]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_unknown_raises_key_error(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await reg.approve_plugin(db_session, "ghost-plugin")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_increments_count(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_install(db_session, "plugin-github-sync")
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_uninstall_decrements_count(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_install(db_session, "plugin-github-sync")
|
||||
await reg.record_install(db_session, "plugin-github-sync")
|
||||
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_uninstall_floors_at_zero(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReviewQueue (DB-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReviewQueue:
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def queue(self) -> ReviewQueue:
|
||||
return ReviewQueue()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_returns_submitted_plugins(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
pending = await queue.get_pending(db_session)
|
||||
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_review_approved(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "approved"
|
||||
# Check review row was persisted
|
||||
review_result = await db_session.execute(
|
||||
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
|
||||
)
|
||||
review = review_result.scalar_one()
|
||||
assert review.decision == "approved"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_review_rejected(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await queue.submit_review(
|
||||
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
|
||||
)
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "rejected"
|
||||
|
||||
def test_validate_manifest_ok(self) -> None:
|
||||
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||
validate_manifest(manifest) # should not raise
|
||||
|
||||
def test_validate_manifest_unknown_permission(self) -> None:
|
||||
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
|
||||
with pytest.raises(ValueError, match="Unknown permission"):
|
||||
validate_manifest(manifest)
|
||||
|
||||
def test_validate_manifest_invalid_id_format(self) -> None:
|
||||
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
|
||||
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||
validate_manifest(manifest)
|
||||
|
||||
def test_validate_manifest_id_with_uppercase(self) -> None:
|
||||
manifest = _fresh_manifest(plugin_id="UpperCase")
|
||||
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||
validate_manifest(manifest)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RevenueShare (DB-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRevenueShare:
|
||||
@pytest.fixture
|
||||
def rs(self) -> RevenueShare:
|
||||
return RevenueShare()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_free_plugin(
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||
result = await db_session.execute(
|
||||
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
|
||||
)
|
||||
event = result.scalar_one()
|
||||
assert event.developer_share_cents == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_paid_plugin_no_stripe(
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await rs.record_install(
|
||||
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
|
||||
)
|
||||
result = await db_session.execute(
|
||||
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
|
||||
)
|
||||
event = result.scalar_one()
|
||||
assert event.amount_cents == 499
|
||||
assert event.developer_share_cents == int(499 * 0.70)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_increments_registry_count(
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
reg = PluginRegistry()
|
||||
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_earnings_empty(
|
||||
self, rs: RevenueShare, db_session: AsyncSession
|
||||
) -> None:
|
||||
result = await rs.get_earnings(db_session, "unknown-dev")
|
||||
assert result["total_installs"] == 0
|
||||
assert result["total_revenue_cents"] == 0
|
||||
assert result["developer_share_cents"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_earnings_aggregates(
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
|
||||
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
|
||||
result = await rs.get_earnings(db_session, "Adiuva")
|
||||
assert result["total_installs"] == 2
|
||||
assert result["total_revenue_cents"] == 998
|
||||
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginRoutes:
|
||||
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "plugins" in data
|
||||
assert data["total"] == 3
|
||||
|
||||
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_get_plugin_found(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||
assert "install_count" in data
|
||||
|
||||
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_install_plugin_free(self, client, seed_plugins) -> None:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=auth_header(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert "download_url" in data
|
||||
|
||||
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/ghost/install",
|
||||
json={"plugin_id": "ghost"},
|
||||
headers=auth_header(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
|
||||
resp = client.delete(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
headers=auth_header(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["ok"] is True
|
||||
|
||||
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=auth_header("free"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
@@ -1,562 +0,0 @@
|
||||
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||
from app.schemas import VectorItem, VectorSearchResult
|
||||
from tests.conftest import auth_header, S3_TEST_BUCKET
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
_BLOB = b"encrypted-payload-opaque-to-server"
|
||||
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||
_BUCKET = S3_TEST_BUCKET
|
||||
_REGION = "us-east-1"
|
||||
|
||||
|
||||
def _pinecone_mock():
|
||||
"""Return a mock Pinecone index with realistic return shapes."""
|
||||
mock_index = MagicMock()
|
||||
mock_index.query.return_value = {
|
||||
"matches": [
|
||||
{
|
||||
"id": "v1",
|
||||
"score": 0.95,
|
||||
"metadata": {
|
||||
"blob": base64.b64encode(b"result-blob").decode(),
|
||||
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
|
||||
"user_id": "u1",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_pc = MagicMock()
|
||||
mock_pc.return_value.Index.return_value = mock_index
|
||||
return mock_pc, mock_index
|
||||
|
||||
|
||||
# ── TestEncryption ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEncryption:
|
||||
def test_verify_checksum_correct(self) -> None:
|
||||
assert verify_checksum(_BLOB, _CHECKSUM) is True
|
||||
|
||||
def test_verify_checksum_wrong(self) -> None:
|
||||
assert verify_checksum(_BLOB, "0" * 64) is False
|
||||
|
||||
def test_verify_checksum_empty_checksum(self) -> None:
|
||||
assert verify_checksum(_BLOB, "") is False
|
||||
|
||||
def test_verify_checksum_empty_blob(self) -> None:
|
||||
expected = hashlib.sha256(b"").hexdigest()
|
||||
assert verify_checksum(b"", expected) is True
|
||||
|
||||
def test_verify_checksum_tampered_blob(self) -> None:
|
||||
tampered = _BLOB + b"\x00"
|
||||
assert verify_checksum(tampered, _CHECKSUM) is False
|
||||
|
||||
def test_reject_if_tampered_passes_when_valid(self) -> None:
|
||||
# Should not raise
|
||||
reject_if_tampered(_BLOB, _CHECKSUM)
|
||||
|
||||
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
reject_if_tampered(_BLOB, "bad" * 20)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
reject_if_tampered(_BLOB, "bad" * 20)
|
||||
assert "checksum" in exc_info.value.detail.lower()
|
||||
|
||||
def test_checksum_is_sha256_hex(self) -> None:
|
||||
cs = hashlib.sha256(_BLOB).hexdigest()
|
||||
assert len(cs) == 64
|
||||
assert all(c in "0123456789abcdef" for c in cs)
|
||||
|
||||
|
||||
# ── TestBlobStore ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBlobStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
assert key == "u1/tasks/r1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
# Verify by downloading — no exception means object exists
|
||||
retrieved = await store.download("u1", "u1/tasks/r1")
|
||||
assert retrieved == _BLOB
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
|
||||
result = await store.download("u1", "u1/notes/n1")
|
||||
assert result == b"note-data"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_object(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.delete("u1", "u1/tasks/r1")
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
await store.download("u1", "u1/tasks/r1")
|
||||
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
# Delete a key that never existed — should not raise
|
||||
await store.delete("u1", "u1/tasks/nonexistent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert "u1/notes/n1" not in keys
|
||||
assert "u1/tasks/r1" in keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
keys_u1 = await store.list_keys("u1", "tasks")
|
||||
assert "u2/tasks/r1" not in keys_u1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert keys == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
# Verify S3 metadata was set — check via head_object
|
||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||
mock_settings.S3_BUCKET = _BUCKET
|
||||
mock_settings.S3_REGION = _REGION
|
||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||
assert response.get("ServerSideEncryption") == "AES256"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||
assert response["Metadata"]["checksum"] == _CHECKSUM
|
||||
|
||||
|
||||
# ── _blob_to_vector helper ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBlobToVector:
|
||||
def test_returns_32_floats(self) -> None:
|
||||
v = _blob_to_vector(b"test")
|
||||
assert len(v) == 32
|
||||
|
||||
def test_all_values_in_range(self) -> None:
|
||||
v = _blob_to_vector(b"test")
|
||||
assert all(-1.0 <= x <= 1.0 for x in v)
|
||||
|
||||
def test_deterministic(self) -> None:
|
||||
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
|
||||
|
||||
def test_different_blobs_different_vectors(self) -> None:
|
||||
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
|
||||
|
||||
|
||||
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVectorStorePinecone:
|
||||
def _store(self) -> VectorStore:
|
||||
store = VectorStore()
|
||||
store._use_pinecone = lambda: True # type: ignore[method-assign]
|
||||
return store
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_calls_index_upsert(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
mock_index.upsert.assert_called_once()
|
||||
call_kwargs = mock_index.upsert.call_args[1]
|
||||
assert call_kwargs.get("namespace") == "u1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
|
||||
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_calls_index_query(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query-blob", top_k=5)
|
||||
mock_index.query.assert_called_once()
|
||||
query_kwargs = mock_index.query.call_args[1]
|
||||
assert query_kwargs.get("namespace") == "u1"
|
||||
assert query_kwargs.get("top_k") == 5
|
||||
assert query_kwargs.get("include_metadata") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_vector_search_results(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
results = await store.search("u1", b"query", top_k=10)
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VectorSearchResult)
|
||||
assert results[0].id == "v1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].blob == b"result-blob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_uses_derived_query_vector(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query-blob", top_k=3)
|
||||
expected_vector = _blob_to_vector(b"query-blob")
|
||||
actual_vector = mock_index.query.call_args[1].get("vector")
|
||||
assert actual_vector == expected_vector
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_calls_index_delete(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1", "v2"])
|
||||
mock_index.delete.assert_called_once()
|
||||
delete_kwargs = mock_index.delete.call_args[1]
|
||||
assert delete_kwargs.get("namespace") == "u1"
|
||||
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
|
||||
|
||||
|
||||
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVectorStoreQdrant:
|
||||
def _store(self) -> VectorStore:
|
||||
store = VectorStore()
|
||||
store._use_pinecone = lambda: False # type: ignore[method-assign]
|
||||
return store
|
||||
|
||||
def _qdrant_mock(self) -> MagicMock:
|
||||
mock_hit = MagicMock()
|
||||
mock_hit.id = "v1"
|
||||
mock_hit.score = 0.88
|
||||
mock_hit.payload = {
|
||||
"blob": base64.b64encode(b"qdrant-result").decode(),
|
||||
"user_id": "u1",
|
||||
}
|
||||
mock_client = MagicMock()
|
||||
mock_client.search.return_value = [mock_hit]
|
||||
return mock_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_calls_client_upsert(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
mock_client.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_uses_correct_collection(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
call_kwargs = mock_client.upsert.call_args[1]
|
||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_calls_client_search(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query", top_k=5)
|
||||
mock_client.search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_passes_limit(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query", top_k=7)
|
||||
call_kwargs = mock_client.search.call_args[1]
|
||||
assert call_kwargs.get("limit") == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_vector_search_results(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
results = await store.search("u1", b"query", top_k=5)
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VectorSearchResult)
|
||||
assert results[0].id == "v1"
|
||||
assert results[0].score == 0.88
|
||||
assert results[0].blob == b"qdrant-result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_calls_client_delete(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1", "v2"])
|
||||
mock_client.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_uses_correct_collection(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1"])
|
||||
call_kwargs = mock_client.delete.call_args[1]
|
||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||
|
||||
|
||||
# ── TestStorageRoutes (integration) ───────────────────────────────────
|
||||
|
||||
|
||||
class TestStorageRoutes:
|
||||
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
|
||||
|
||||
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
|
||||
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
|
||||
ASCII strings as blob values and compute checksums accordingly.
|
||||
"""
|
||||
|
||||
_BLOB_STR = "encrypted-payload-opaque-to-server"
|
||||
_BLOB_BYTES = _BLOB_STR.encode()
|
||||
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _create_payload(cls, blob_str: str | None = None) -> dict:
|
||||
blob_str = blob_str or cls._BLOB_STR
|
||||
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
|
||||
return {
|
||||
"table": "tasks",
|
||||
"blob": blob_str,
|
||||
"checksum": checksum,
|
||||
}
|
||||
|
||||
def _create_record(self, client, tier="power", blob_str=None):
|
||||
payload = self._create_payload(blob_str)
|
||||
return client.post(
|
||||
"/api/v1/storage/records",
|
||||
json=payload,
|
||||
headers=auth_header(tier),
|
||||
)
|
||||
|
||||
# ── Create ────────────────────────────────────────────────────────
|
||||
|
||||
def test_create_record(self, client, s3_bucket) -> None:
|
||||
resp = self._create_record(client)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert "id" in data
|
||||
assert "created_at" in data
|
||||
|
||||
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||
payload = {
|
||||
"table": "tasks",
|
||||
"blob": self._BLOB_STR,
|
||||
"checksum": "0" * 64,
|
||||
}
|
||||
resp = client.post(
|
||||
"/api/v1/storage/records",
|
||||
json=payload,
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||
"""Free tier has cloud_storage_gb=0 → 402."""
|
||||
resp = self._create_record(client, tier="free")
|
||||
assert resp.status_code == 402
|
||||
|
||||
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
|
||||
resp = self._create_record(client, tier="pro")
|
||||
assert resp.status_code == 201
|
||||
|
||||
# ── List ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_list_records(self, client, s3_bucket) -> None:
|
||||
self._create_record(client)
|
||||
self._create_record(client, blob_str="second-blob")
|
||||
resp = client.get(
|
||||
"/api/v1/storage/records",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
# Each entry has metadata, no blob bytes
|
||||
for item in data:
|
||||
assert "id" in item
|
||||
assert "table" in item
|
||||
assert "checksum" in item
|
||||
assert "blob" not in item
|
||||
|
||||
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
|
||||
self._create_record(client)
|
||||
# Create in a different table
|
||||
note_blob = "note-blob"
|
||||
payload = {
|
||||
"table": "notes",
|
||||
"blob": note_blob,
|
||||
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
|
||||
}
|
||||
client.post(
|
||||
"/api/v1/storage/records",
|
||||
json=payload,
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
resp = client.get(
|
||||
"/api/v1/storage/records?table=notes",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["table"] == "notes"
|
||||
|
||||
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
|
||||
"""One user's records should not appear in another user's list."""
|
||||
self._create_record(client, tier="power")
|
||||
resp = client.get(
|
||||
"/api/v1/storage/records",
|
||||
headers=auth_header("team"),
|
||||
)
|
||||
assert resp.json() == []
|
||||
|
||||
# ── Download ──────────────────────────────────────────────────────
|
||||
|
||||
def test_download_record(self, client, s3_bucket) -> None:
|
||||
create_resp = self._create_record(client)
|
||||
record_id = create_resp.json()["id"]
|
||||
resp = client.get(
|
||||
f"/api/v1/storage/records/{record_id}",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.content == self._BLOB_BYTES
|
||||
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
|
||||
|
||||
def test_download_record_not_found(self, client, s3_bucket) -> None:
|
||||
resp = client.get(
|
||||
"/api/v1/storage/records/nonexistent-id",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
# ── Update ────────────────────────────────────────────────────────
|
||||
|
||||
def test_update_record(self, client, s3_bucket) -> None:
|
||||
create_resp = self._create_record(client)
|
||||
record_id = create_resp.json()["id"]
|
||||
new_blob_str = "updated-encrypted-payload"
|
||||
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
|
||||
resp = client.put(
|
||||
f"/api/v1/storage/records/{record_id}",
|
||||
json={"blob": new_blob_str, "checksum": new_checksum},
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"ok": True}
|
||||
|
||||
# Verify download returns the updated blob
|
||||
dl = client.get(
|
||||
f"/api/v1/storage/records/{record_id}",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert dl.content == new_blob_str.encode()
|
||||
|
||||
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||
create_resp = self._create_record(client)
|
||||
record_id = create_resp.json()["id"]
|
||||
resp = client.put(
|
||||
f"/api/v1/storage/records/{record_id}",
|
||||
json={"blob": "some-data", "checksum": "0" * 64},
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
# ── Delete ────────────────────────────────────────────────────────
|
||||
|
||||
def test_delete_record(self, client, s3_bucket) -> None:
|
||||
create_resp = self._create_record(client)
|
||||
record_id = create_resp.json()["id"]
|
||||
resp = client.delete(
|
||||
f"/api/v1/storage/records/{record_id}",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"ok": True}
|
||||
|
||||
# Subsequent GET should return 404
|
||||
dl = client.get(
|
||||
f"/api/v1/storage/records/{record_id}",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert dl.status_code == 404
|
||||
|
||||
def test_delete_record_not_found(self, client, s3_bucket) -> None:
|
||||
resp = client.delete(
|
||||
"/api/v1/storage/records/nonexistent",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
Reference in New Issue
Block a user