"""Tier manager: feature matrix and quota enforcement. ``TierManager`` is the single source of truth for what each billing tier allows. ``get_tier`` reads from the ``StripeService`` in-memory store until Step 12 replaces it with a live PostgreSQL lookup. """ from __future__ import annotations from typing import Any from fastapi import HTTPException, status from app.schemas import BillingTier # Feature matrix per tier. -1 means unlimited; 0 means disabled. FEATURES: dict[str, dict[str, Any]] = { "free": { "agents": 3, "batch_active": 2, "cloud_storage_gb": 0, "backup_gb": 0, "providers": 1, "batch_builder": False, "plugin_marketplace": False, "sso": False, }, "pro": { "agents": -1, # unlimited "batch_active": 10, "cloud_storage_gb": 5, "backup_gb": 5, "providers": -1, "batch_builder": False, "plugin_marketplace": False, "sso": False, }, "power": { "agents": -1, "batch_active": -1, # unlimited "cloud_storage_gb": 25, "backup_gb": 25, "providers": -1, "batch_builder": True, "plugin_marketplace": True, "sso": False, }, "team": { "agents": -1, "batch_active": -1, "cloud_storage_gb": -1, # unlimited "backup_gb": -1, # unlimited "providers": -1, "batch_builder": True, "plugin_marketplace": True, "sso": True, }, } # Requests-per-minute limit per tier. RATE_LIMITS: dict[str, int] = { "free": 20, "pro": 60, "power": 120, "team": 200, } class TierManager: """Centralises tier feature-gating, rate-limit lookups, and quota checks. ``get_tier`` consults the ``StripeService`` singleton. Step 12 will replace that with a PostgreSQL query so that the tier is always fresh. """ # ── Tier lookup ───────────────────────────────────────────────────── def get_tier(self, user_id: str) -> BillingTier: """Return the current billing tier for ``user_id``. Falls back to ``'free'`` when no subscription record exists. Step 12 will replace this with a live DB lookup. """ # Import here to avoid circular imports at module load time. from app.billing.stripe_service import stripe_service # noqa: PLC0415 sub = stripe_service.get_subscription(user_id) if sub is None: return "free" tier = sub.get("tier", "free") # Validate against known tiers; unknown values fall back to free. if tier not in FEATURES: return "free" return tier # type: ignore[return-value] # ── Feature access ─────────────────────────────────────────────────── def check_feature(self, user_id: str, feature: str) -> bool: """Return ``True`` if ``user_id``'s current tier has ``feature`` enabled. For numeric features, any value > 0 or -1 (unlimited) counts as enabled. """ tier = self.get_tier(user_id) value = FEATURES[tier].get(feature) if value is None: return False if isinstance(value, bool): return value # Numeric: -1 means unlimited (enabled), 0 means disabled. return value != 0 def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None: """Raise ``HTTP 403`` if ``user_id`` does not have ``feature``. ``tier_name`` is used in the error message to tell users which tier they need to upgrade to. """ if not self.check_feature(user_id, feature): detail = ( f"Feature '{feature}' requires {tier_name} tier or above." if tier_name else f"Feature '{feature}' is not available on your current tier." ) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) # ── Rate limiting ──────────────────────────────────────────────────── def get_rate_limit(self, tier: BillingTier) -> int: """Return the requests-per-minute limit for ``tier``.""" return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) # ── Storage quota ──────────────────────────────────────────────────── def check_quota( self, user_id: str, current_bytes: int = 0, additional_bytes: int = 0, ) -> bool: """Return ``True`` if ``user_id`` can store ``additional_bytes`` more data. ``current_bytes`` is the user's current storage usage (from the caller's record-keeping). Step 12 will remove these parameters and query the DB directly. Returns ``False`` if the tier has no storage allocation at all (free tier), or if ``current_bytes + additional_bytes`` would exceed the tier's ``cloud_storage_gb`` limit. """ tier = self.get_tier(user_id) limit_gb: int = FEATURES[tier]["cloud_storage_gb"] if limit_gb == 0: return False # tier has no storage if limit_gb == -1: return True # unlimited limit_bytes = limit_gb * 1024 ** 3 return current_bytes + additional_bytes <= limit_bytes def enforce_quota( self, user_id: str, current_bytes: int = 0, additional_bytes: int = 0, ) -> None: """Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota.""" tier = self.get_tier(user_id) limit_gb: int = FEATURES[tier]["cloud_storage_gb"] if limit_gb == 0: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Cloud storage is not available on the '{tier}' tier", ) if limit_gb == -1: return # unlimited limit_bytes = limit_gb * 1024 ** 3 if current_bytes + additional_bytes > limit_bytes: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Storage quota exceeded for tier '{tier}'", ) def enforce_backup_quota( self, user_id: str, current_bytes: int = 0, additional_bytes: int = 0, ) -> None: """Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota.""" tier = self.get_tier(user_id) limit_gb: int = FEATURES[tier]["backup_gb"] if limit_gb == 0: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Backup is not available on the '{tier}' tier", ) if limit_gb == -1: return # unlimited limit_bytes = limit_gb * 1024 ** 3 if current_bytes + additional_bytes > limit_bytes: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Backup quota exceeded for tier '{tier}'", ) # Module-level singleton shared across the app. tier_manager = TierManager()