Step 12 - completed

This commit is contained in:
2026-03-03 14:53:34 +01:00
parent 5d485b3665
commit d0b303e745
13 changed files with 950 additions and 487 deletions

View File

@@ -1,20 +1,23 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import time
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -23,9 +26,6 @@ router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
_records: dict[str, dict[str, Any]] = {}
# ── Local response schemas ─────────────────────────────────────────────
@@ -44,17 +44,34 @@ class _RecordMeta(BaseModel):
# ── Helpers ────────────────────────────────────────────────────────────
def _check_quota(user_id: str, additional_bytes: int) -> None:
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
"""Look up a record and verify ownership. Always returns 404 on mismatch
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
record = _records.get(record_id)
if record is None or record["user_id"] != user_id:
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
@@ -65,30 +82,32 @@ def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
_check_quota(current_user.id, len(body.blob))
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
now = int(time.time() * 1000)
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
_records[record_id] = {
"id": record_id,
"user_id": current_user.id,
"table": body.table,
"s3_key": s3_key,
"checksum": body.checksum,
"size_bytes": len(body.blob),
"created_at": now,
"updated_at": now,
}
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
return _CreateResponse(id=record_id, created_at=now)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
@@ -97,23 +116,26 @@ async def list_records(
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
all_records = [
r for r in _records.values()
if r["user_id"] == current_user.id and (table is None or r["table"] == table)
]
start = (page - 1) * limit
page_records = all_records[start : start + limit]
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r["id"],
table=r["table"],
checksum=r["checksum"],
created_at=r["created_at"],
updated_at=r["updated_at"],
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in page_records
for r in rows
]
@@ -121,14 +143,15 @@ async def list_records(
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = _get_record_for_user(record_id, current_user.id)
blob = await _blob_store.download(current_user.id, record["s3_key"])
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record["checksum"]},
headers={"X-Checksum": record.checksum},
)
@@ -137,23 +160,24 @@ async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = _get_record_for_user(record_id, current_user.id)
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record["size_bytes"]
delta = len(body.blob) - record.size_bytes
if delta > 0:
_check_quota(current_user.id, delta)
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record["table"], record_id, body.blob, body.checksum
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record["s3_key"] = s3_key
record["checksum"] = body.checksum
record["size_bytes"] = len(body.blob)
record["updated_at"] = int(time.time() * 1000)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@@ -162,9 +186,11 @@ async def update_record(
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = _get_record_for_user(record_id, current_user.id)
await _blob_store.delete(current_user.id, record["s3_key"])
del _records[record_id]
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}