From 9787befd4a042f694be44363959ded8ad550687a Mon Sep 17 00:00:00 2001 From: roberto Date: Mon, 2 Mar 2026 22:41:35 +0100 Subject: [PATCH] step 11 complete: billing service and tier manager Co-Authored-By: Claude Sonnet 4.6 --- BACKEND_PLAN.md | 9 +- app/api/routes/backup.py | 30 +---- app/api/routes/billing.py | 126 ++------------------- app/api/routes/storage.py | 27 +---- app/billing/__init__.py | 4 + app/billing/stripe_service.py | 183 ++++++++++++++++++++++++++++++ app/billing/tier_manager.py | 207 ++++++++++++++++++++++++++++++++++ 7 files changed, 422 insertions(+), 164 deletions(-) create mode 100644 app/billing/stripe_service.py create mode 100644 app/billing/tier_manager.py diff --git a/BACKEND_PLAN.md b/BACKEND_PLAN.md index 90f9656..b450f98 100644 --- a/BACKEND_PLAN.md +++ b/BACKEND_PLAN.md @@ -376,13 +376,13 @@ adiuva-api/ - `async get_earnings(developer_id, period) -> dict` - **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split. -### Step 11 — Billing & Tier management -- [ ] `app/billing/stripe_service.py`: +### Step 11 — Billing & Tier management ✅ +- [x] `app/billing/stripe_service.py`: - `create_checkout_session(user_id, tier) -> str` - `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` - `get_subscription(user_id) -> dict | None` - `cancel_subscription(user_id) -> None` -- [ ] `app/billing/tier_manager.py`: +- [x] `app/billing/tier_manager.py`: - `TierManager`: - Feature matrix: ```python @@ -433,6 +433,9 @@ adiuva-api/ - `check_feature(user_id, feature) -> bool` - `get_rate_limit(tier) -> int` - `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit +- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons +- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService` +- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota` - **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat). ### Step 12 — Database (auth/billing/marketplace only) diff --git a/app/api/routes/backup.py b/app/api/routes/backup.py index ff73f11..bb8821a 100644 --- a/app/api/routes/backup.py +++ b/app/api/routes/backup.py @@ -16,6 +16,7 @@ from typing import Any from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status from app.api.deps import get_current_user +from app.billing.tier_manager import tier_manager from app.schemas import BackupMetadata, UserProfile from app.storage.blob_store import BlobStore from app.storage.encryption import reject_if_tampered @@ -27,32 +28,11 @@ _blob_store = BlobStore() # In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12 _backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records -# TODO(Step11/12): replace with TierManager.check_quota(user_id) -_TIER_BACKUP_LIMITS_GB: dict[str, int] = { - "free": 0, - "pro": 5, - "power": 25, - "team": -1, # unlimited -} - -def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None: +def _check_backup_quota(user_id: str, size_bytes: int) -> None: """Raise HTTP 402 if the upload would exceed the tier's backup limit.""" - limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0) - if limit_gb == 0: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail="Backup is not available on the free tier", - ) - if limit_gb == -1: - return # unlimited - limit_bytes = limit_gb * 1024**3 - used = sum(b["size_bytes"] for b in _backups.get(user_id, [])) - if used + size_bytes > limit_bytes: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Backup quota exceeded for tier '{tier}'", - ) + current = sum(b["size_bytes"] for b in _backups.get(user_id, [])) + tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes) @router.put("") @@ -69,7 +49,7 @@ async def upload_backup( """ blob = await request.body() reject_if_tampered(blob, x_backup_checksum) - _check_backup_quota(current_user.id, current_user.tier, len(blob)) + _check_backup_quota(current_user.id, len(blob)) s3_key = await _blob_store.upload( current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum diff --git a/app/api/routes/billing.py b/app/api/routes/billing.py index ccc2ca2..6ca1aa7 100644 --- a/app/api/routes/billing.py +++ b/app/api/routes/billing.py @@ -1,44 +1,23 @@ """Billing routes: Stripe checkout, webhook, subscription management. -Subscription records are kept in-memory until Step 12 migrates them to -PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when -STRIPE_SECRET_KEY is not configured, allowing local development without keys. +Business logic lives in ``app.billing.stripe_service.StripeService``. +The route layer handles HTTP concerns (request parsing, response shaping) +and delegates everything else to the service singleton. """ from __future__ import annotations from typing import Any -import stripe as stripe_lib -from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from fastapi import APIRouter, Depends, Header, Request, status from pydantic import BaseModel from app.api.deps import get_current_user -from app.config.settings import settings +from app.billing.stripe_service import stripe_service from app.schemas import BillingTier, UserProfile router = APIRouter(prefix="/billing", tags=["billing"]) -# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12 -_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record - -_TIER_PRICE_IDS: dict[str, str] = { - "pro": "price_pro_monthly", # replace with real Stripe price IDs - "power": "price_power_monthly", - "team": "price_team_monthly", -} - - -# ── Helpers ──────────────────────────────────────────────────────────── - -def _stripe_configured() -> bool: - return bool(settings.STRIPE_SECRET_KEY) - - -def _stripe() -> Any: - stripe_lib.api_key = settings.STRIPE_SECRET_KEY - return stripe_lib - # ── Request bodies ───────────────────────────────────────────────────── @@ -57,34 +36,8 @@ async def create_checkout( Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured. """ - if body.tier == "free": - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot create a checkout session for the free tier", - ) - - if _stripe_configured(): - price_id = _TIER_PRICE_IDS.get(body.tier) - if not price_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Unknown tier: {body.tier}", - ) - s = _stripe() - session = s.checkout.Session.create( - payment_method_types=["card"], - mode="subscription", - line_items=[{"price": price_id, "quantity": 1}], - success_url=( - "https://app.adiuva.app/billing/success" - "?session_id={CHECKOUT_SESSION_ID}" - ), - cancel_url="https://app.adiuva.app/billing/cancel", - metadata={"user_id": current_user.id, "tier": body.tier}, - ) - return {"checkout_url": session.url} - - return {"checkout_url": "https://stripe.com/stub-checkout"} + url = stripe_service.create_checkout_session(current_user.id, body.tier) + return {"checkout_url": url} @router.post("/webhook", response_model=dict) @@ -98,48 +51,7 @@ async def stripe_webhook( Returns 200 immediately when Stripe is not configured (local dev). """ payload = await request.body() - - if not _stripe_configured(): - return {"ok": True} - - try: - s = _stripe() - event = s.Webhook.construct_event( - payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET - ) - except stripe_lib.error.SignatureVerificationError: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid Stripe signature", - ) - - event_type: str = event["type"] - data: dict[str, Any] = event["data"]["object"] - - if event_type == "checkout.session.completed": - user_id = data.get("metadata", {}).get("user_id") - tier = data.get("metadata", {}).get("tier", "free") - sub_id = data.get("subscription") - if user_id: - _subscriptions[user_id] = { - "tier": tier, - "stripe_subscription_id": sub_id, - "status": "active", - "current_period_end": None, - } - - elif event_type == "customer.subscription.updated": - # TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier - pass - - elif event_type == "customer.subscription.deleted": - # TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free - pass - - elif event_type == "invoice.payment_failed": - # TODO(Step12): flag subscription as past_due, notify user - pass - + stripe_service.handle_webhook(payload, stripe_signature) return {"ok": True} @@ -148,7 +60,7 @@ async def get_subscription( current_user: UserProfile = Depends(get_current_user), ) -> dict[str, Any]: """Return the current subscription info for the authenticated user.""" - sub = _subscriptions.get(current_user.id) + sub = stripe_service.get_subscription(current_user.id) if sub is None: return { "tier": current_user.tier, @@ -159,26 +71,10 @@ async def get_subscription( return sub -@router.delete("/subscription", response_model=dict) +@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK) async def cancel_subscription( current_user: UserProfile = Depends(get_current_user), ) -> dict[str, bool]: """Cancel the active subscription.""" - sub = _subscriptions.get(current_user.id) - if sub is None or not sub.get("stripe_subscription_id"): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No active subscription found", - ) - - if _stripe_configured(): - s = _stripe() - s.Subscription.cancel(sub["stripe_subscription_id"]) - - _subscriptions[current_user.id] = { - **sub, - "tier": "free", - "status": "canceled", - } - + stripe_service.cancel_subscription(current_user.id) return {"ok": True} diff --git a/app/api/routes/storage.py b/app/api/routes/storage.py index 8db7067..beb5747 100644 --- a/app/api/routes/storage.py +++ b/app/api/routes/storage.py @@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Response, status from pydantic import BaseModel from app.api.deps import get_current_user +from app.billing.tier_manager import tier_manager from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile from app.storage.blob_store import BlobStore from app.storage.encryption import reject_if_tampered @@ -25,14 +26,6 @@ _blob_store = BlobStore() # In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12 _records: dict[str, dict[str, Any]] = {} -# TODO(Step11/12): replace with TierManager.check_quota(user_id) -_TIER_STORAGE_LIMITS_GB: dict[str, int] = { - "free": 0, - "pro": 5, - "power": 25, - "team": -1, # unlimited -} - # ── Local response schemas ───────────────────────────────────────────── @@ -51,18 +44,10 @@ class _RecordMeta(BaseModel): # ── Helpers ──────────────────────────────────────────────────────────── -def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None: +def _check_quota(user_id: str, additional_bytes: int) -> None: """Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit.""" - limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0) - if limit_gb == -1: - return # unlimited - limit_bytes = limit_gb * 1024**3 - used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id) - if used + additional_bytes > limit_bytes: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=f"Storage quota exceeded for tier '{tier}'", - ) + current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id) + tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes) def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]: @@ -83,7 +68,7 @@ async def create_record( ) -> _CreateResponse: """Upload a new E2E-encrypted blob. Verifies checksum before storing.""" reject_if_tampered(body.blob, body.checksum) - _check_quota(current_user.id, current_user.tier, len(body.blob)) + _check_quota(current_user.id, len(body.blob)) record_id = str(uuid.uuid4()) now = int(time.time() * 1000) @@ -159,7 +144,7 @@ async def update_record( delta = len(body.blob) - record["size_bytes"] if delta > 0: - _check_quota(current_user.id, current_user.tier, delta) + _check_quota(current_user.id, delta) s3_key = await _blob_store.upload( current_user.id, record["table"], record_id, body.blob, body.checksum diff --git a/app/billing/__init__.py b/app/billing/__init__.py index e69de29..ef83f83 100644 --- a/app/billing/__init__.py +++ b/app/billing/__init__.py @@ -0,0 +1,4 @@ +from app.billing.stripe_service import stripe_service +from app.billing.tier_manager import tier_manager + +__all__ = ["stripe_service", "tier_manager"] diff --git a/app/billing/stripe_service.py b/app/billing/stripe_service.py new file mode 100644 index 0000000..0c68ded --- /dev/null +++ b/app/billing/stripe_service.py @@ -0,0 +1,183 @@ +"""Stripe service: checkout sessions, webhook handling, subscription management. + +Subscriptions are stored in-memory until Step 12 migrates them to the +PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed +when ``STRIPE_SECRET_KEY`` is not configured, enabling local development +without live credentials. +""" + +from __future__ import annotations + +from typing import Any + +import stripe as stripe_lib +from fastapi import HTTPException, status + +from app.config.settings import settings + +# Stripe price IDs per tier — replace with real IDs in production .env +TIER_PRICE_IDS: dict[str, str] = { + "pro": "price_pro_monthly", + "power": "price_power_monthly", + "team": "price_team_monthly", +} + + +class StripeService: + """Wraps all Stripe interactions and owns the in-memory subscription store. + + Step 12 will replace ``_subscriptions`` with real PostgreSQL queries. + """ + + def __init__(self) -> None: + # user_id → subscription record dict + # Replaced by the ``subscriptions`` table in Step 12. + self._subscriptions: dict[str, dict[str, Any]] = {} + + # ── Internal helpers ──────────────────────────────────────────────── + + def _configured(self) -> bool: + return bool(settings.STRIPE_SECRET_KEY) + + def _client(self) -> Any: + stripe_lib.api_key = settings.STRIPE_SECRET_KEY + return stripe_lib + + # ── Public API ────────────────────────────────────────────────────── + + def create_checkout_session( + self, + user_id: str, + tier: str, + success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}", + cancel_url: str = "https://app.adiuva.app/billing/cancel", + ) -> str: + """Create a Stripe checkout session and return the URL. + + Returns a stub URL when Stripe is not configured. + Raises ``HTTP 400`` for the free tier or an unknown tier. + """ + if tier == "free": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot create a checkout session for the free tier", + ) + + price_id = TIER_PRICE_IDS.get(tier) + if not price_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unknown tier: {tier}", + ) + + if not self._configured(): + return "https://stripe.com/stub-checkout" + + s = self._client() + session = s.checkout.Session.create( + payment_method_types=["card"], + mode="subscription", + line_items=[{"price": price_id, "quantity": 1}], + success_url=success_url, + cancel_url=cancel_url, + metadata={"user_id": user_id, "tier": tier}, + ) + return session.url + + def handle_webhook(self, payload: bytes, sig_header: str) -> None: + """Process a Stripe webhook event. + + Verifies the signature, then dispatches on event type. + Raises ``HTTP 400`` on signature mismatch. + No-ops when Stripe is not configured. + """ + if not self._configured(): + return + + try: + s = self._client() + event = s.Webhook.construct_event( + payload, sig_header, settings.STRIPE_WEBHOOK_SECRET + ) + except stripe_lib.error.SignatureVerificationError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid Stripe signature", + ) + + event_type: str = event["type"] + data: dict[str, Any] = event["data"]["object"] + + if event_type == "checkout.session.completed": + user_id = data.get("metadata", {}).get("user_id") + tier = data.get("metadata", {}).get("tier", "free") + sub_id = data.get("subscription") + period_end = data.get("current_period_end") + if user_id: + self._subscriptions[user_id] = { + "tier": tier, + "stripe_subscription_id": sub_id, + "status": "active", + "current_period_end": period_end, + } + + elif event_type == "customer.subscription.updated": + # TODO(Step12): look up user_id from stripe_customer_id in DB, update tier + sub_id = data.get("id") + new_status = data.get("status") + period_end = data.get("current_period_end") + for record in self._subscriptions.values(): + if record.get("stripe_subscription_id") == sub_id: + record["status"] = new_status + record["current_period_end"] = period_end + break + + elif event_type == "customer.subscription.deleted": + # TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free + sub_id = data.get("id") + for user_id, record in self._subscriptions.items(): + if record.get("stripe_subscription_id") == sub_id: + self._subscriptions[user_id] = { + **record, + "tier": "free", + "status": "canceled", + } + break + + elif event_type == "invoice.payment_failed": + # TODO(Step12): flag subscription as past_due, notify user + sub_id = data.get("subscription") + for record in self._subscriptions.values(): + if record.get("stripe_subscription_id") == sub_id: + record["status"] = "past_due" + break + + def get_subscription(self, user_id: str) -> dict[str, Any] | None: + """Return the subscription record for ``user_id``, or ``None`` if absent.""" + return self._subscriptions.get(user_id) + + def cancel_subscription(self, user_id: str) -> None: + """Cancel the user's Stripe subscription and downgrade them to free. + + Raises ``HTTP 404`` when no active subscription exists. + """ + sub = self._subscriptions.get(user_id) + if sub is None or not sub.get("stripe_subscription_id"): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No active subscription found", + ) + + if self._configured(): + s = self._client() + s.Subscription.cancel(sub["stripe_subscription_id"]) + + self._subscriptions[user_id] = { + **sub, + "tier": "free", + "status": "canceled", + } + + +# Module-level singleton shared across the app. +stripe_service = StripeService() diff --git a/app/billing/tier_manager.py b/app/billing/tier_manager.py new file mode 100644 index 0000000..fbd6e5d --- /dev/null +++ b/app/billing/tier_manager.py @@ -0,0 +1,207 @@ +"""Tier manager: feature matrix and quota enforcement. + +``TierManager`` is the single source of truth for what each billing tier +allows. ``get_tier`` reads from the ``StripeService`` in-memory store until +Step 12 replaces it with a live PostgreSQL lookup. +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import HTTPException, status + +from app.schemas import BillingTier + +# Feature matrix per tier. -1 means unlimited; 0 means disabled. +FEATURES: dict[str, dict[str, Any]] = { + "free": { + "agents": 3, + "batch_active": 2, + "cloud_storage_gb": 0, + "backup_gb": 0, + "providers": 1, + "batch_builder": False, + "plugin_marketplace": False, + "sso": False, + }, + "pro": { + "agents": -1, # unlimited + "batch_active": 10, + "cloud_storage_gb": 5, + "backup_gb": 5, + "providers": -1, + "batch_builder": False, + "plugin_marketplace": False, + "sso": False, + }, + "power": { + "agents": -1, + "batch_active": -1, # unlimited + "cloud_storage_gb": 25, + "backup_gb": 25, + "providers": -1, + "batch_builder": True, + "plugin_marketplace": True, + "sso": False, + }, + "team": { + "agents": -1, + "batch_active": -1, + "cloud_storage_gb": -1, # unlimited + "backup_gb": -1, # unlimited + "providers": -1, + "batch_builder": True, + "plugin_marketplace": True, + "sso": True, + }, +} + +# Requests-per-minute limit per tier. +RATE_LIMITS: dict[str, int] = { + "free": 20, + "pro": 60, + "power": 120, + "team": 200, +} + + +class TierManager: + """Centralises tier feature-gating, rate-limit lookups, and quota checks. + + ``get_tier`` consults the ``StripeService`` singleton. Step 12 will + replace that with a PostgreSQL query so that the tier is always fresh. + """ + + # ── Tier lookup ───────────────────────────────────────────────────── + + def get_tier(self, user_id: str) -> BillingTier: + """Return the current billing tier for ``user_id``. + + Falls back to ``'free'`` when no subscription record exists. + Step 12 will replace this with a live DB lookup. + """ + # Import here to avoid circular imports at module load time. + from app.billing.stripe_service import stripe_service # noqa: PLC0415 + + sub = stripe_service.get_subscription(user_id) + if sub is None: + return "free" + tier = sub.get("tier", "free") + # Validate against known tiers; unknown values fall back to free. + if tier not in FEATURES: + return "free" + return tier # type: ignore[return-value] + + # ── Feature access ─────────────────────────────────────────────────── + + def check_feature(self, user_id: str, feature: str) -> bool: + """Return ``True`` if ``user_id``'s current tier has ``feature`` enabled. + + For numeric features, any value > 0 or -1 (unlimited) counts as enabled. + """ + tier = self.get_tier(user_id) + value = FEATURES[tier].get(feature) + if value is None: + return False + if isinstance(value, bool): + return value + # Numeric: -1 means unlimited (enabled), 0 means disabled. + return value != 0 + + def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None: + """Raise ``HTTP 403`` if ``user_id`` does not have ``feature``. + + ``tier_name`` is used in the error message to tell users which tier + they need to upgrade to. + """ + if not self.check_feature(user_id, feature): + detail = ( + f"Feature '{feature}' requires {tier_name} tier or above." + if tier_name + else f"Feature '{feature}' is not available on your current tier." + ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) + + # ── Rate limiting ──────────────────────────────────────────────────── + + def get_rate_limit(self, tier: BillingTier) -> int: + """Return the requests-per-minute limit for ``tier``.""" + return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) + + # ── Storage quota ──────────────────────────────────────────────────── + + def check_quota( + self, + user_id: str, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> bool: + """Return ``True`` if ``user_id`` can store ``additional_bytes`` more data. + + ``current_bytes`` is the user's current storage usage (from the + caller's record-keeping). Step 12 will remove these parameters and + query the DB directly. + + Returns ``False`` if the tier has no storage allocation at all + (free tier), or if ``current_bytes + additional_bytes`` would exceed + the tier's ``cloud_storage_gb`` limit. + """ + tier = self.get_tier(user_id) + limit_gb: int = FEATURES[tier]["cloud_storage_gb"] + if limit_gb == 0: + return False # tier has no storage + if limit_gb == -1: + return True # unlimited + limit_bytes = limit_gb * 1024 ** 3 + return current_bytes + additional_bytes <= limit_bytes + + def enforce_quota( + self, + user_id: str, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> None: + """Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota.""" + tier = self.get_tier(user_id) + limit_gb: int = FEATURES[tier]["cloud_storage_gb"] + if limit_gb == 0: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Cloud storage is not available on the '{tier}' tier", + ) + if limit_gb == -1: + return # unlimited + limit_bytes = limit_gb * 1024 ** 3 + if current_bytes + additional_bytes > limit_bytes: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Storage quota exceeded for tier '{tier}'", + ) + + def enforce_backup_quota( + self, + user_id: str, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> None: + """Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota.""" + tier = self.get_tier(user_id) + limit_gb: int = FEATURES[tier]["backup_gb"] + if limit_gb == 0: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Backup is not available on the '{tier}' tier", + ) + if limit_gb == -1: + return # unlimited + limit_bytes = limit_gb * 1024 ** 3 + if current_bytes + additional_bytes > limit_bytes: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Backup quota exceeded for tier '{tier}'", + ) + + +# Module-level singleton shared across the app. +tier_manager = TierManager()