This commit is contained in:
2026-03-03 12:39:32 +01:00
parent 9787befd4a
commit 5d485b3665
12 changed files with 999 additions and 165 deletions

View File

@@ -1,8 +1,9 @@
"""Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier
allows. ``get_tier`` reads from the ``StripeService`` in-memory store until
Step 12 replaces it with a live PostgreSQL lookup.
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
Quota-enforcement helpers take ``tier`` directly — the caller already has it
from ``current_user.tier`` (provided by ``get_current_user``).
"""
from __future__ import annotations
@@ -10,6 +11,8 @@ from __future__ import annotations
from typing import Any
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import BillingTier
@@ -67,55 +70,42 @@ RATE_LIMITS: dict[str, int] = {
class TierManager:
"""Centralises tier feature-gating, rate-limit lookups, and quota checks.
``get_tier`` consults the ``StripeService`` singleton. Step 12 will
replace that with a PostgreSQL query so that the tier is always fresh.
"""
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
# ── Tier lookup ─────────────────────────────────────────────────────
def get_tier(self, user_id: str) -> BillingTier:
"""Return the current billing tier for ``user_id``.
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'free'`` when no subscription record exists.
Step 12 will replace this with a live DB lookup.
Falls back to ``'free'`` when no subscription row exists.
"""
# Import here to avoid circular imports at module load time.
from app.billing.stripe_service import stripe_service # noqa: PLC0415
from app.models import Subscription # noqa: PLC0415
sub = stripe_service.get_subscription(user_id)
if sub is None:
return "free"
tier = sub.get("tier", "free")
# Validate against known tiers; unknown values fall back to free.
if tier not in FEATURES:
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────
def check_feature(self, user_id: str, feature: str) -> bool:
"""Return ``True`` if ``user_id``'s current tier has ``feature`` enabled.
def check_feature(self, tier: BillingTier, feature: str) -> bool:
"""Return ``True`` if ``tier`` has ``feature`` enabled.
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
"""
tier = self.get_tier(user_id)
value = FEATURES[tier].get(feature)
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if value is None:
return False
if isinstance(value, bool):
return value
# Numeric: -1 means unlimited (enabled), 0 means disabled.
return value != 0
def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``user_id`` does not have ``feature``.
``tier_name`` is used in the error message to tell users which tier
they need to upgrade to.
"""
if not self.check_feature(user_id, feature):
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
if not self.check_feature(tier, feature):
detail = (
f"Feature '{feature}' requires {tier_name} tier or above."
if tier_name
@@ -131,39 +121,17 @@ class TierManager:
# ── Storage quota ────────────────────────────────────────────────────
def check_quota(
self,
user_id: str,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if ``user_id`` can store ``additional_bytes`` more data.
``current_bytes`` is the user's current storage usage (from the
caller's record-keeping). Step 12 will remove these parameters and
query the DB directly.
Returns ``False`` if the tier has no storage allocation at all
(free tier), or if ``current_bytes + additional_bytes`` would exceed
the tier's ``cloud_storage_gb`` limit.
"""
tier = self.get_tier(user_id)
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False # tier has no storage
if limit_gb == -1:
return True # unlimited
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
def enforce_quota(
self,
user_id: str,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota."""
tier = self.get_tier(user_id)
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
@@ -181,12 +149,11 @@ class TierManager:
def enforce_backup_quota(
self,
user_id: str,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota."""
tier = self.get_tier(user_id)
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
@@ -202,6 +169,21 @@ class TierManager:
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app.
tier_manager = TierManager()