diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..6c3e72f --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,64 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install ruff + run: pip install ruff>=0.8.0 + + - name: Ruff check + run: ruff check . + + - name: Ruff format check + run: ruff format --check . + + test: + name: Test + runs-on: ubuntu-latest + needs: lint + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: ${{ runner.os }}-pip- + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Run tests + run: pytest -v --tb=short + + docker: + name: Docker Build + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + + - name: Build image + run: docker build -t adiuva-api:ci . + + - name: Verify gunicorn installed + run: docker run --rm adiuva-api:ci gunicorn --version diff --git a/BACKEND_PLAN.md b/BACKEND_PLAN.md index bc37989..ab6d3c9 100644 --- a/BACKEND_PLAN.md +++ b/BACKEND_PLAN.md @@ -453,16 +453,16 @@ adiuva-api/ - [x] SQLAlchemy models in `app/models.py` - **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext. -### Step 13 — Testing & deployment -- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone -- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode -- [ ] `tests/test_agents.py`: each agent with mocked tools -- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token -- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement -- [ ] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement -- [ ] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked) -- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers) -- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image +### Step 13 — Testing & deployment ✅ +- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone +- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode +- [x] `tests/test_agents.py`: each agent with mocked tools +- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token +- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement +- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement +- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked) +- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers) +- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image - **Outcome:** Fully tested, deployable backend. --- diff --git a/Dockerfile b/Dockerfile index 2de9a06..32496db 100644 --- a/Dockerfile +++ b/Dockerfile @@ -21,6 +21,10 @@ COPY --from=builder /install /usr/local # Copy application source COPY app/ app/ +# Copy Alembic migration files +COPY alembic/ alembic/ +COPY alembic.ini . + # Ensure appuser owns the working directory RUN chown -R appuser:appgroup /app @@ -28,4 +32,8 @@ USER appuser EXPOSE 8000 -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"] +CMD ["gunicorn", "app.main:app", \ + "-k", "uvicorn.workers.UvicornWorker", \ + "--bind", "0.0.0.0:8000", \ + "--workers", "4", \ + "--timeout", "120"] diff --git a/requirements.txt b/requirements.txt index b0d98ed..8436567 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ fastapi>=0.115.0 uvicorn[standard]>=0.34.0 +gunicorn>=22.0.0 langchain>=0.3.0 langchain-openai>=0.3.0 pydantic>=2.10.0 @@ -22,3 +23,4 @@ aiosqlite>=0.20.0 moto[s3]>=5.0.0 pinecone>=5.0.0 qdrant-client>=1.7.0 +ruff>=0.8.0 diff --git a/tests/conftest.py b/tests/conftest.py index a4837d7..d4b5438 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,15 +6,20 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it. from __future__ import annotations +import hashlib import json +import os import time import uuid from collections.abc import AsyncGenerator, Generator +from unittest.mock import patch +import boto3 import pytest import pytest_asyncio from fastapi.testclient import TestClient from jose import jwt +from moto import mock_aws from sqlalchemy import StaticPool, event from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine @@ -206,3 +211,26 @@ def make_jwt( def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]: """Return an Authorization header dict for the given tier.""" return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} + + +# ── S3 mock fixture ────────────────────────────────────────────────── + +S3_TEST_BUCKET = "test-bucket" +S3_TEST_REGION = "us-east-1" + + +@pytest.fixture +def s3_bucket(): + """Create a mocked S3 bucket via moto and patch BlobStore settings.""" + with mock_aws(): + os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing") + os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing") + os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION) + client = boto3.client("s3", region_name=S3_TEST_REGION) + client.create_bucket(Bucket=S3_TEST_BUCKET) + with patch("app.storage.blob_store.settings") as mock_settings: + mock_settings.S3_BUCKET = S3_TEST_BUCKET + mock_settings.S3_REGION = S3_TEST_REGION + mock_settings.AWS_ACCESS_KEY_ID = "testing" + mock_settings.AWS_SECRET_ACCESS_KEY = "testing" + yield S3_TEST_BUCKET diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..db8f46e --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,207 @@ +"""Tests for auth routes: register, login, refresh, me. + +Exercises the full auth lifecycle through the FastAPI TestClient against the +in-memory SQLite test database seeded by ``conftest.py``. +""" + +from __future__ import annotations + +import time + +import pytest +from jose import jwt + +from app.config.settings import settings +from tests.conftest import auth_header, make_jwt, TEST_USER_IDS + + +# ── TestRegister ────────────────────────────────────────────────────── + + +class TestRegister: + """POST /api/v1/auth/register""" + + def test_register_success(self, client) -> None: + resp = client.post( + "/api/v1/auth/register", + json={"email": "new@example.com", "password": "Str0ngP@ss!"}, + ) + assert resp.status_code == 201 + data = resp.json() + assert "access_token" in data + assert "refresh_token" in data + assert "expires_at" in data + # expires_at should be a future millisecond timestamp + assert data["expires_at"] > int(time.time() * 1000) + + def test_register_returns_valid_jwt(self, client) -> None: + resp = client.post( + "/api/v1/auth/register", + json={"email": "jwt-check@example.com", "password": "P@ss1234"}, + ) + assert resp.status_code == 201 + token = resp.json()["access_token"] + payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]) + assert payload["email"] == "jwt-check@example.com" + assert payload["tier"] == "free" + assert "sub" in payload + + def test_register_duplicate_email(self, client) -> None: + client.post( + "/api/v1/auth/register", + json={"email": "dupe@example.com", "password": "Pass1234"}, + ) + resp = client.post( + "/api/v1/auth/register", + json={"email": "dupe@example.com", "password": "Pass5678"}, + ) + assert resp.status_code == 409 + + def test_register_missing_password(self, client) -> None: + resp = client.post( + "/api/v1/auth/register", + json={"email": "no-pass@example.com"}, + ) + assert resp.status_code == 422 + + def test_register_missing_email(self, client) -> None: + resp = client.post( + "/api/v1/auth/register", + json={"password": "OnlyPass"}, + ) + assert resp.status_code == 422 + + +# ── TestLogin ───────────────────────────────────────────────────────── + + +class TestLogin: + """POST /api/v1/auth/login""" + + def _register(self, client, email="login@example.com", password="MyP@ss123"): + client.post( + "/api/v1/auth/register", + json={"email": email, "password": password}, + ) + + def test_login_success(self, client) -> None: + self._register(client) + resp = client.post( + "/api/v1/auth/login", + json={"email": "login@example.com", "password": "MyP@ss123"}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "access_token" in data + assert "refresh_token" in data + assert "expires_at" in data + + def test_login_wrong_password(self, client) -> None: + self._register(client) + resp = client.post( + "/api/v1/auth/login", + json={"email": "login@example.com", "password": "WrongPass!"}, + ) + assert resp.status_code == 401 + + def test_login_unknown_email(self, client) -> None: + resp = client.post( + "/api/v1/auth/login", + json={"email": "ghost@example.com", "password": "Whatever"}, + ) + assert resp.status_code == 401 + + +# ── TestRefresh ─────────────────────────────────────────────────────── + + +class TestRefresh: + """POST /api/v1/auth/refresh""" + + def _register_and_get_tokens(self, client, email="refresh@example.com"): + resp = client.post( + "/api/v1/auth/register", + json={"email": email, "password": "RefPass123!"}, + ) + return resp.json() + + def test_refresh_returns_new_tokens(self, client) -> None: + tokens = self._register_and_get_tokens(client) + resp = client.post( + "/api/v1/auth/refresh", + json={"refresh_token": tokens["refresh_token"]}, + ) + assert resp.status_code == 200 + data = resp.json() + assert "access_token" in data + assert "refresh_token" in data + # New refresh token should differ from old one (rotation) + assert data["refresh_token"] != tokens["refresh_token"] + + def test_refresh_old_token_rejected(self, client) -> None: + """After rotation, the original refresh token must be rejected.""" + tokens = self._register_and_get_tokens(client, email="rotate@example.com") + old_rt = tokens["refresh_token"] + + # First refresh succeeds and rotates the token + client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt}) + + # Second attempt with the old token must fail + resp = client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt}) + assert resp.status_code == 401 + + def test_refresh_bogus_token(self, client) -> None: + resp = client.post( + "/api/v1/auth/refresh", + json={"refresh_token": "not-a-real-token"}, + ) + assert resp.status_code == 401 + + +# ── TestMe ──────────────────────────────────────────────────────────── + + +class TestMe: + """GET /api/v1/auth/me""" + + def test_me_with_valid_jwt(self, client) -> None: + resp = client.get("/api/v1/auth/me", headers=auth_header("power")) + assert resp.status_code == 200 + data = resp.json() + assert data["id"] == TEST_USER_IDS["power"] + assert data["email"] == "power@test.com" + assert data["tier"] == "power" + + def test_me_returns_correct_tier(self, client) -> None: + """Tier comes from the live subscription row, not the JWT claim.""" + resp = client.get("/api/v1/auth/me", headers=auth_header("free")) + assert resp.json()["tier"] == "free" + + def test_me_missing_token(self, client) -> None: + resp = client.get("/api/v1/auth/me") + assert resp.status_code == 401 + + def test_me_expired_token(self, client) -> None: + """A JWT with ``exp`` in the past must be rejected.""" + payload = { + "sub": TEST_USER_IDS["power"], + "email": "power@test.com", + "tier": "power", + "exp": int(time.time()) - 3600, # 1 hour ago + "iat": int(time.time()) - 7200, + } + token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) + resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 401 + + def test_me_invalid_signature(self, client) -> None: + payload = { + "sub": TEST_USER_IDS["power"], + "email": "power@test.com", + "tier": "power", + "exp": int(time.time()) + 3600, + "iat": int(time.time()), + } + token = jwt.encode(payload, "wrong-secret", algorithm="HS256") + resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert resp.status_code == 401 diff --git a/tests/test_backup.py b/tests/test_backup.py new file mode 100644 index 0000000..2d3253d --- /dev/null +++ b/tests/test_backup.py @@ -0,0 +1,244 @@ +"""Tests for backup routes: upload, download, history, delete. + +Exercises the backup lifecycle through the FastAPI TestClient against the +in-memory SQLite test database and moto-mocked S3 bucket. +""" + +from __future__ import annotations + +import hashlib + +import pytest + +from tests.conftest import auth_header, TEST_USER_IDS + + +# ── Helpers ─────────────────────────────────────────────────────────── + +_BLOB = b"encrypted-backup-blob-opaque-bytes" +_CHECKSUM = hashlib.sha256(_BLOB).hexdigest() +_VERSION = 1 +_TIMESTAMP = 1700000000000 # arbitrary ms timestamp + + +def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]: + """Return auth + backup metadata headers.""" + headers = auth_header(tier) + headers["X-Backup-Version"] = str(overrides.get("version", _VERSION)) + headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP)) + headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM) + headers["Content-Type"] = "application/octet-stream" + return headers + + +def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821 + """Upload a backup blob and return the response.""" + return client.put( + "/api/v1/backup", + content=overrides.pop("blob", _BLOB), + headers=_backup_headers(tier, **overrides), + ) + + +# ── TestUploadBackup ────────────────────────────────────────────────── + + +class TestUploadBackup: + """PUT /api/v1/backup""" + + def test_upload_success(self, client, s3_bucket) -> None: + resp = _upload(client, tier="power") + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + def test_upload_creates_history_entry(self, client, s3_bucket) -> None: + _upload(client, tier="power") + history = client.get( + "/api/v1/backup/history", headers=auth_header("power") + ).json() + assert len(history) == 1 + assert history[0]["version"] == _VERSION + assert history[0]["timestamp"] == _TIMESTAMP + assert history[0]["checksum"] == _CHECKSUM + + def test_upload_bad_checksum(self, client, s3_bucket) -> None: + resp = _upload(client, tier="power", checksum="0" * 64) + assert resp.status_code == 400 + + def test_upload_free_tier_blocked(self, client, s3_bucket) -> None: + """Free tier has backup_gb=0 → should return 402.""" + resp = _upload(client, tier="free") + assert resp.status_code == 402 + + def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None: + """Pro tier has backup_gb=5 → small blob succeeds.""" + resp = _upload(client, tier="pro") + assert resp.status_code == 200 + + +# ── TestDownloadBackup ──────────────────────────────────────────────── + + +class TestDownloadBackup: + """GET /api/v1/backup""" + + def test_download_latest(self, client, s3_bucket) -> None: + _upload(client, tier="power") + resp = client.get("/api/v1/backup", headers=auth_header("power")) + assert resp.status_code == 200 + assert resp.content == _BLOB + assert resp.headers["X-Checksum"] == _CHECKSUM + assert resp.headers["X-Backup-Version"] == str(_VERSION) + + def test_download_no_backup_returns_404(self, client, s3_bucket) -> None: + resp = client.get("/api/v1/backup", headers=auth_header("power")) + assert resp.status_code == 404 + + def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None: + """When If-Modified-Since is after the backup timestamp → 304.""" + _upload(client, tier="power", timestamp=1700000000000) + resp = client.get( + "/api/v1/backup", + headers={ + **auth_header("power"), + "If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT", + }, + ) + assert resp.status_code == 304 + + def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None: + """When If-Modified-Since is before the backup timestamp → serve blob.""" + _upload(client, tier="power", timestamp=1700000000000) + resp = client.get( + "/api/v1/backup", + headers={ + **auth_header("power"), + "If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT", + }, + ) + assert resp.status_code == 200 + assert resp.content == _BLOB + + def test_download_multiple_returns_latest(self, client, s3_bucket) -> None: + """When multiple backups exist, GET returns the one with the highest timestamp.""" + _upload(client, tier="power", timestamp=1000) + blob2 = b"second-encrypted-backup" + checksum2 = hashlib.sha256(blob2).hexdigest() + _upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2) + resp = client.get("/api/v1/backup", headers=auth_header("power")) + assert resp.status_code == 200 + assert resp.content == blob2 + + +# ── TestBackupHistory ───────────────────────────────────────────────── + + +class TestBackupHistory: + """GET /api/v1/backup/history""" + + def test_history_empty(self, client, s3_bucket) -> None: + resp = client.get("/api/v1/backup/history", headers=auth_header("power")) + assert resp.status_code == 200 + assert resp.json() == [] + + def test_history_returns_entries(self, client, s3_bucket) -> None: + _upload(client, tier="power", timestamp=1000) + _upload(client, tier="power", timestamp=2000) + history = client.get( + "/api/v1/backup/history", headers=auth_header("power") + ).json() + assert len(history) == 2 + # Ordered by timestamp descending + assert history[0]["timestamp"] == 2000 + assert history[1]["timestamp"] == 1000 + + def test_history_isolated_per_user(self, client, s3_bucket) -> None: + """One user's backups should not appear in another user's history.""" + _upload(client, tier="power") + resp = client.get("/api/v1/backup/history", headers=auth_header("team")) + assert resp.json() == [] + + +# ── TestDeleteBackup ────────────────────────────────────────────────── + + +class TestDeleteBackup: + """DELETE /api/v1/backup/{backup_id}""" + + def _get_backup_id(self, client, tier="power") -> str: + """Upload a backup and return its DB id from history.""" + _upload(client, tier=tier) + history = client.get( + "/api/v1/backup/history", headers=auth_header(tier) + ).json() + # History returns BackupMetadata schema which doesn't have `id`. + # We need to look it up via a different means. + # Since there's only 1 backup, find via history length. + # Actually the schema doesn't return id — let's verify via re-download. + # We'll use a workaround: upload, then list history to confirm it exists, + # then try to delete — but we need the id... + # Let's check if history includes an id field. + # The schema is: version, timestamp, checksum, chunk_count — no id. + # We'll need to query the DB directly or use a known ID. + # For testing, we'll search history then use the DB. + return None # pragma: no cover — overridden below + + def test_delete_success(self, client, s3_bucket, db_session) -> None: + _upload(client, tier="power") + + # Discover the backup_id via direct DB query + import asyncio + from sqlalchemy import select + from app.models import BackupMetadata + + async def _get_id(): + result = await db_session.execute( + select(BackupMetadata.id).where( + BackupMetadata.user_id == TEST_USER_IDS["power"] + ) + ) + return result.scalar_one() + + backup_id = asyncio.get_event_loop().run_until_complete(_get_id()) + + resp = client.delete( + f"/api/v1/backup/{backup_id}", headers=auth_header("power") + ) + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + # History should now be empty + history = client.get( + "/api/v1/backup/history", headers=auth_header("power") + ).json() + assert history == [] + + def test_delete_nonexistent(self, client, s3_bucket) -> None: + resp = client.delete( + "/api/v1/backup/no-such-id", headers=auth_header("power") + ) + assert resp.status_code == 404 + + def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None: + """Cannot delete another user's backup (ownership check returns 404).""" + _upload(client, tier="power") + + import asyncio + from sqlalchemy import select + from app.models import BackupMetadata + + async def _get_id(): + result = await db_session.execute( + select(BackupMetadata.id).where( + BackupMetadata.user_id == TEST_USER_IDS["power"] + ) + ) + return result.scalar_one() + + backup_id = asyncio.get_event_loop().run_until_complete(_get_id()) + + # team user tries to delete power user's backup → 404 + resp = client.delete( + f"/api/v1/backup/{backup_id}", headers=auth_header("team") + ) + assert resp.status_code == 404 diff --git a/tests/test_storage.py b/tests/test_storage.py index 3e6a7dc..881854d 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,48 +1,30 @@ -"""Tests for the storage layer: encryption, BlobStore, and VectorStore.""" +"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes.""" from __future__ import annotations import base64 import hashlib -import os from unittest.mock import MagicMock, patch import boto3 import pytest from botocore.exceptions import ClientError -from moto import mock_aws from app.storage.encryption import reject_if_tampered, verify_checksum from app.storage.blob_store import BlobStore from app.storage.vector_store import VectorStore, _blob_to_vector from app.schemas import VectorItem, VectorSearchResult +from tests.conftest import auth_header, S3_TEST_BUCKET # ── Helpers ─────────────────────────────────────────────────────────── _BLOB = b"encrypted-payload-opaque-to-server" _CHECKSUM = hashlib.sha256(_BLOB).hexdigest() -_BUCKET = "test-bucket" +_BUCKET = S3_TEST_BUCKET _REGION = "us-east-1" -@pytest.fixture -def s3_bucket(): - """Create a mocked S3 bucket and expose its name.""" - with mock_aws(): - os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing") - os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing") - os.environ.setdefault("AWS_DEFAULT_REGION", _REGION) - client = boto3.client("s3", region_name=_REGION) - client.create_bucket(Bucket=_BUCKET) - with patch("app.storage.blob_store.settings") as mock_settings: - mock_settings.S3_BUCKET = _BUCKET - mock_settings.S3_REGION = _REGION - mock_settings.AWS_ACCESS_KEY_ID = "testing" - mock_settings.AWS_SECRET_ACCESS_KEY = "testing" - yield _BUCKET - - def _pinecone_mock(): """Return a mock Pinecone index with realistic return shapes.""" mock_index = MagicMock() @@ -383,3 +365,198 @@ class TestVectorStoreQdrant: await store.delete("u1", ["v1"]) call_kwargs = mock_client.delete.call_args[1] assert call_kwargs["collection_name"] == "adiuva_vectors" + + +# ── TestStorageRoutes (integration) ─────────────────────────────────── + + +class TestStorageRoutes: + """Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records. + + Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``. + So "hello" in JSON becomes ``b"hello"`` on the server. We use plain + ASCII strings as blob values and compute checksums accordingly. + """ + + _BLOB_STR = "encrypted-payload-opaque-to-server" + _BLOB_BYTES = _BLOB_STR.encode() + _BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest() + + @classmethod + def _create_payload(cls, blob_str: str | None = None) -> dict: + blob_str = blob_str or cls._BLOB_STR + checksum = hashlib.sha256(blob_str.encode()).hexdigest() + return { + "table": "tasks", + "blob": blob_str, + "checksum": checksum, + } + + def _create_record(self, client, tier="power", blob_str=None): + payload = self._create_payload(blob_str) + return client.post( + "/api/v1/storage/records", + json=payload, + headers=auth_header(tier), + ) + + # ── Create ──────────────────────────────────────────────────────── + + def test_create_record(self, client, s3_bucket) -> None: + resp = self._create_record(client) + assert resp.status_code == 201 + data = resp.json() + assert "id" in data + assert "created_at" in data + + def test_create_record_bad_checksum(self, client, s3_bucket) -> None: + payload = { + "table": "tasks", + "blob": self._BLOB_STR, + "checksum": "0" * 64, + } + resp = client.post( + "/api/v1/storage/records", + json=payload, + headers=auth_header("power"), + ) + assert resp.status_code == 400 + + def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None: + """Free tier has cloud_storage_gb=0 → 402.""" + resp = self._create_record(client, tier="free") + assert resp.status_code == 402 + + def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None: + """Pro tier has cloud_storage_gb=5 → succeeds for small blob.""" + resp = self._create_record(client, tier="pro") + assert resp.status_code == 201 + + # ── List ────────────────────────────────────────────────────────── + + def test_list_records(self, client, s3_bucket) -> None: + self._create_record(client) + self._create_record(client, blob_str="second-blob") + resp = client.get( + "/api/v1/storage/records", + headers=auth_header("power"), + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 2 + # Each entry has metadata, no blob bytes + for item in data: + assert "id" in item + assert "table" in item + assert "checksum" in item + assert "blob" not in item + + def test_list_records_filter_by_table(self, client, s3_bucket) -> None: + self._create_record(client) + # Create in a different table + note_blob = "note-blob" + payload = { + "table": "notes", + "blob": note_blob, + "checksum": hashlib.sha256(note_blob.encode()).hexdigest(), + } + client.post( + "/api/v1/storage/records", + json=payload, + headers=auth_header("power"), + ) + resp = client.get( + "/api/v1/storage/records?table=notes", + headers=auth_header("power"), + ) + assert resp.status_code == 200 + data = resp.json() + assert len(data) == 1 + assert data[0]["table"] == "notes" + + def test_list_records_isolated_per_user(self, client, s3_bucket) -> None: + """One user's records should not appear in another user's list.""" + self._create_record(client, tier="power") + resp = client.get( + "/api/v1/storage/records", + headers=auth_header("team"), + ) + assert resp.json() == [] + + # ── Download ────────────────────────────────────────────────────── + + def test_download_record(self, client, s3_bucket) -> None: + create_resp = self._create_record(client) + record_id = create_resp.json()["id"] + resp = client.get( + f"/api/v1/storage/records/{record_id}", + headers=auth_header("power"), + ) + assert resp.status_code == 200 + assert resp.content == self._BLOB_BYTES + assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM + + def test_download_record_not_found(self, client, s3_bucket) -> None: + resp = client.get( + "/api/v1/storage/records/nonexistent-id", + headers=auth_header("power"), + ) + assert resp.status_code == 404 + + # ── Update ──────────────────────────────────────────────────────── + + def test_update_record(self, client, s3_bucket) -> None: + create_resp = self._create_record(client) + record_id = create_resp.json()["id"] + new_blob_str = "updated-encrypted-payload" + new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest() + resp = client.put( + f"/api/v1/storage/records/{record_id}", + json={"blob": new_blob_str, "checksum": new_checksum}, + headers=auth_header("power"), + ) + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + # Verify download returns the updated blob + dl = client.get( + f"/api/v1/storage/records/{record_id}", + headers=auth_header("power"), + ) + assert dl.content == new_blob_str.encode() + + def test_update_record_bad_checksum(self, client, s3_bucket) -> None: + create_resp = self._create_record(client) + record_id = create_resp.json()["id"] + resp = client.put( + f"/api/v1/storage/records/{record_id}", + json={"blob": "some-data", "checksum": "0" * 64}, + headers=auth_header("power"), + ) + assert resp.status_code == 400 + + # ── Delete ──────────────────────────────────────────────────────── + + def test_delete_record(self, client, s3_bucket) -> None: + create_resp = self._create_record(client) + record_id = create_resp.json()["id"] + resp = client.delete( + f"/api/v1/storage/records/{record_id}", + headers=auth_header("power"), + ) + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + # Subsequent GET should return 404 + dl = client.get( + f"/api/v1/storage/records/{record_id}", + headers=auth_header("power"), + ) + assert dl.status_code == 404 + + def test_delete_record_not_found(self, client, s3_bucket) -> None: + resp = client.delete( + "/api/v1/storage/records/nonexistent", + headers=auth_header("power"), + ) + assert resp.status_code == 404