"""Tier manager: feature matrix and quota enforcement. Single source of truth for what each billing tier allows. Other services can query the /tier/{user_id} endpoint or rely on the X-User-Tier header injected by Traefik. """ from __future__ import annotations from typing import Any from fastapi import HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from shared.config import settings from shared.models import Subscription from shared.schemas import BillingTier # Feature matrix per tier. -1 means unlimited; 0 means disabled. FEATURES: dict[str, dict[str, Any]] = { "free": { "agents": 3, "batch_active": 2, "batch_runs_per_day": 5, "cloud_storage_gb": 0, "backup_gb": 0, "providers": 1, "batch_builder": False, "plugin_marketplace": False, "sso": False, }, "pro": { "agents": -1, "batch_active": 10, "batch_runs_per_day": 50, "cloud_storage_gb": 5, "backup_gb": 5, "providers": -1, "batch_builder": False, "plugin_marketplace": False, "sso": False, }, "power": { "agents": -1, "batch_active": -1, "batch_runs_per_day": -1, "cloud_storage_gb": 25, "backup_gb": 25, "providers": -1, "batch_builder": True, "plugin_marketplace": True, "sso": False, }, "team": { "agents": -1, "batch_active": -1, "batch_runs_per_day": -1, "cloud_storage_gb": -1, "backup_gb": -1, "providers": -1, "batch_builder": True, "plugin_marketplace": True, "sso": True, }, } # Requests-per-minute limit per tier. 
RATE_LIMITS: dict[str, int] = {
    "free": 20,
    "pro": 60,
    "power": 120,
    "team": 200,
}

# Quotas in FEATURES are expressed in GB (1024 ** 3 bytes) and compared
# against byte counts supplied by callers.
_BYTES_PER_GB = 1024 ** 3


class TierManager:
    """Centralises tier feature-gating, rate-limit lookups, and quota checks.

    Every lookup treats a tier name that is not a key of FEATURES as "free",
    so an unexpected tier value never raises and never grants extra limits.
    """

    async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
        """Return the current billing tier for ``user_id`` from the DB.

        If there is no subscription row, or the stored tier is not a key of
        FEATURES, return "power" when ``settings.ENV == "dev"`` and "free"
        otherwise.
        """
        result = await db.execute(
            select(Subscription.tier).where(Subscription.user_id == user_id)
        )
        # NOTE(review): scalar_one_or_none() raises if more than one row
        # matches; presumably user_id is unique on Subscription. Confirm.
        tier: str | None = result.scalar_one_or_none()
        if tier is None or tier not in FEATURES:
            return "power" if settings.ENV == "dev" else "free"
        return tier  # type: ignore[return-value]

    def _features(self, tier: BillingTier) -> dict[str, Any]:
        """Return the shared (uncopied) feature dict for ``tier``, defaulting to free."""
        return FEATURES.get(tier, FEATURES["free"])

    def get_features(self, tier: BillingTier) -> dict[str, Any]:
        """Return the full feature dict for ``tier`` (free tier if unknown).

        A shallow copy is returned so a caller mutating the result cannot
        change the module-level FEATURES matrix for every other request.
        """
        return dict(self._features(tier))

    def check_feature(self, tier: BillingTier, feature: str) -> bool:
        """Return True if ``tier`` has ``feature`` enabled.

        Missing features are disabled. Boolean entries are returned as-is.
        Numeric entries are enabled unless they equal 0, so -1 (unlimited)
        and any positive cap both count as enabled.
        """
        value = self._features(tier).get(feature)
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        return value != 0

    def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
        """Raise HTTP 403 if ``tier`` does not have ``feature``.

        ``tier_name``, when given, names the minimum tier in the error detail.
        """
        if not self.check_feature(tier, feature):
            detail = (
                f"Feature '{feature}' requires {tier_name} tier or above."
                if tier_name
                else f"Feature '{feature}' is not available on your current tier."
            )
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)

    def get_rate_limit(self, tier: BillingTier) -> int:
        """Return the requests-per-minute limit for ``tier`` (free limit if unknown)."""
        return RATE_LIMITS.get(tier, RATE_LIMITS["free"])

    def _quota_bytes(self, tier: BillingTier, feature: str) -> int | None:
        """Return the byte limit for a GB-denominated quota ``feature``.

        Returns None when the quota is unlimited (-1) and 0 when the feature
        is disabled for ``tier``.
        """
        limit_gb: int = self._features(tier)[feature]
        if limit_gb == -1:
            return None
        return limit_gb * _BYTES_PER_GB

    def _enforce_gb_quota(
        self,
        tier: BillingTier,
        feature: str,
        current_bytes: int,
        additional_bytes: int,
        *,
        unavailable_detail: str,
        exceeded_detail: str,
    ) -> None:
        """Raise HTTP 402 if ``feature`` is disabled or the new total exceeds it."""
        limit_bytes = self._quota_bytes(tier, feature)
        if limit_bytes == 0:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=unavailable_detail,
            )
        if limit_bytes is None:
            return
        if current_bytes + additional_bytes > limit_bytes:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=exceeded_detail,
            )

    def enforce_quota(
        self,
        tier: BillingTier,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> None:
        """Raise HTTP 402 if the user would exceed their cloud storage quota.

        Also raises 402 when cloud storage is disabled (0 GB) for ``tier``.
        """
        self._enforce_gb_quota(
            tier,
            "cloud_storage_gb",
            current_bytes,
            additional_bytes,
            unavailable_detail=f"Cloud storage is not available on the '{tier}' tier",
            exceeded_detail=f"Storage quota exceeded for tier '{tier}'",
        )

    def enforce_backup_quota(
        self,
        tier: BillingTier,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> None:
        """Raise HTTP 402 if the user would exceed their backup quota.

        Also raises 402 when backup is disabled (0 GB) for ``tier``.
        """
        self._enforce_gb_quota(
            tier,
            "backup_gb",
            current_bytes,
            additional_bytes,
            unavailable_detail=f"Backup is not available on the '{tier}' tier",
            exceeded_detail=f"Backup quota exceeded for tier '{tier}'",
        )

    def check_quota(
        self,
        tier: BillingTier,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> bool:
        """Return True if the user can store ``additional_bytes`` more data.

        Non-raising counterpart of ``enforce_quota`` for cloud storage: False
        when storage is disabled for ``tier``, True when it is unlimited.
        """
        limit_bytes = self._quota_bytes(tier, "cloud_storage_gb")
        if limit_bytes == 0:
            return False
        if limit_bytes is None:
            return True
        return current_bytes + additional_bytes <= limit_bytes


# Module-level singleton
tier_manager = TierManager()