Files
adiuva-api/app/billing/tier_manager.py
2026-03-02 22:41:35 +01:00

208 lines
7.4 KiB
Python

"""Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier
allows. ``get_tier`` reads from the ``StripeService`` in-memory store until
Step 12 replaces it with a live PostgreSQL lookup.
"""
from __future__ import annotations
from typing import Any
from fastapi import HTTPException, status
from app.schemas import BillingTier
# Sentinel used in the tables below for "no cap".
_UNLIMITED = -1

# Per-tier feature matrix. Numeric entries are caps where -1 means
# unlimited and 0 means disabled; boolean entries are plain on/off flags.
FEATURES: dict[str, dict[str, Any]] = {
    "free": {
        "agents": 3,
        "batch_active": 2,
        "cloud_storage_gb": 0,
        "backup_gb": 0,
        "providers": 1,
        "batch_builder": False,
        "plugin_marketplace": False,
        "sso": False,
    },
    "pro": {
        "agents": _UNLIMITED,
        "batch_active": 10,
        "cloud_storage_gb": 5,
        "backup_gb": 5,
        "providers": _UNLIMITED,
        "batch_builder": False,
        "plugin_marketplace": False,
        "sso": False,
    },
    "power": {
        "agents": _UNLIMITED,
        "batch_active": _UNLIMITED,
        "cloud_storage_gb": 25,
        "backup_gb": 25,
        "providers": _UNLIMITED,
        "batch_builder": True,
        "plugin_marketplace": True,
        "sso": False,
    },
    "team": {
        "agents": _UNLIMITED,
        "batch_active": _UNLIMITED,
        "cloud_storage_gb": _UNLIMITED,
        "backup_gb": _UNLIMITED,
        "providers": _UNLIMITED,
        "batch_builder": True,
        "plugin_marketplace": True,
        "sso": True,
    },
}

# Requests-per-minute ceiling for each tier; keys mirror FEATURES.
RATE_LIMITS: dict[str, int] = {
    "free": 20,
    "pro": 60,
    "power": 120,
    "team": 200,
}
class TierManager:
    """Centralises tier feature-gating, rate-limit lookups, and quota checks.

    ``get_tier`` consults the ``StripeService`` singleton. Step 12 will
    replace that with a PostgreSQL query so that the tier is always fresh.

    Numeric limits read from ``FEATURES`` use the module convention:
    ``-1`` means unlimited and ``0`` means the feature is disabled.
    """

    # Sentinel in FEATURES meaning "no cap".
    _UNLIMITED = -1
    # Storage/backup limits are stored in whole GB and converted to bytes
    # with binary multiples.
    _BYTES_PER_GB = 1024 ** 3

    # ── Tier lookup ─────────────────────────────────────────────────────
    def get_tier(self, user_id: str) -> BillingTier:
        """Return the current billing tier for ``user_id``.

        Falls back to ``"free"`` when there is no subscription record, when
        the record has no ``"tier"`` entry, or when the stored value is not
        one of the tiers defined in ``FEATURES``. Step 12 will replace this
        with a live DB lookup.
        """
        # Imported lazily to avoid circular imports at module load time.
        from app.billing.stripe_service import stripe_service  # noqa: PLC0415

        sub = stripe_service.get_subscription(user_id)
        if sub is None:
            return "free"
        tier = sub.get("tier", "free")
        # Unknown values fall back to free. The isinstance guard also keeps a
        # malformed, unhashable value (e.g. a list) from raising TypeError on
        # the membership test below.
        if not isinstance(tier, str) or tier not in FEATURES:
            return "free"
        return tier  # type: ignore[return-value]

    # ── Feature access ───────────────────────────────────────────────────
    def check_feature(self, user_id: str, feature: str) -> bool:
        """Return ``True`` if ``user_id``'s current tier has ``feature`` enabled.

        Boolean features are returned as-is. Numeric features count as
        enabled for any non-zero value (a positive cap or ``-1`` for
        unlimited). Feature names absent from the tier's matrix are treated
        as disabled.
        """
        tier = self.get_tier(user_id)
        value = FEATURES[tier].get(feature)
        if value is None:
            return False
        # bool is a subclass of int, so it must be handled before the
        # numeric case.
        if isinstance(value, bool):
            return value
        # Numeric: -1 means unlimited (enabled), 0 means disabled.
        return value != 0

    def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None:
        """Raise ``HTTP 403`` if ``user_id`` does not have ``feature``.

        Args:
            user_id: User whose tier is checked.
            feature: Key in ``FEATURES`` to check.
            tier_name: Optional tier to name in the error message so the
                user knows what to upgrade to.

        Raises:
            HTTPException: 403 when the feature is not enabled.
        """
        if self.check_feature(user_id, feature):
            return
        detail = (
            f"Feature '{feature}' requires {tier_name} tier or above."
            if tier_name
            else f"Feature '{feature}' is not available on your current tier."
        )
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)

    # ── Rate limiting ────────────────────────────────────────────────────
    def get_rate_limit(self, tier: BillingTier) -> int:
        """Return the requests-per-minute limit for ``tier``.

        Tiers missing from ``RATE_LIMITS`` get the free-tier limit.
        """
        return RATE_LIMITS.get(tier, RATE_LIMITS["free"])

    # ── Storage quota ────────────────────────────────────────────────────
    def check_quota(
        self,
        user_id: str,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> bool:
        """Return ``True`` if ``user_id`` can store ``additional_bytes`` more data.

        ``current_bytes`` is the user's current storage usage (from the
        caller's record-keeping). Step 12 will remove these parameters and
        query the DB directly.

        Returns ``False`` if the tier has no storage allocation at all
        (free tier), or if ``current_bytes + additional_bytes`` would exceed
        the tier's ``cloud_storage_gb`` limit.
        """
        tier = self.get_tier(user_id)
        return self._fits(
            FEATURES[tier]["cloud_storage_gb"], current_bytes, additional_bytes
        )

    def enforce_quota(
        self,
        user_id: str,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> None:
        """Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota.

        Raises:
            HTTPException: 402 when the tier has no cloud storage, or when
                ``current_bytes + additional_bytes`` exceeds the tier's
                ``cloud_storage_gb`` limit.
        """
        self._enforce_limit(
            user_id,
            "cloud_storage_gb",
            current_bytes,
            additional_bytes,
            service="Cloud storage",
            quota="Storage",
        )

    def enforce_backup_quota(
        self,
        user_id: str,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> None:
        """Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota.

        Raises:
            HTTPException: 402 when the tier has no backup allocation, or
                when ``current_bytes + additional_bytes`` exceeds the tier's
                ``backup_gb`` limit.
        """
        self._enforce_limit(
            user_id,
            "backup_gb",
            current_bytes,
            additional_bytes,
            service="Backup",
            quota="Backup",
        )

    # ── Internal helpers ─────────────────────────────────────────────────
    @classmethod
    def _fits(cls, limit_gb: int, current_bytes: int, additional_bytes: int) -> bool:
        """Return ``True`` if ``current_bytes + additional_bytes`` fits in ``limit_gb``.

        A limit of ``0`` (no allocation) never fits; ``-1`` (unlimited)
        always fits.
        """
        if limit_gb == 0:
            return False
        if limit_gb == cls._UNLIMITED:
            return True
        return current_bytes + additional_bytes <= limit_gb * cls._BYTES_PER_GB

    def _enforce_limit(
        self,
        user_id: str,
        limit_key: str,
        current_bytes: int,
        additional_bytes: int,
        *,
        service: str,
        quota: str,
    ) -> None:
        """Shared body of the ``enforce_*`` methods for a GB-denominated limit.

        Args:
            user_id: User whose tier is checked.
            limit_key: ``FEATURES`` key holding the limit in GB.
            current_bytes: Existing usage reported by the caller.
            additional_bytes: Bytes the caller wants to add.
            service: Noun used in the "not available" message.
            quota: Noun used in the "quota exceeded" message.

        Raises:
            HTTPException: 402 when the tier has a ``0`` limit, or when the
                requested total exceeds a finite limit.
        """
        tier = self.get_tier(user_id)
        limit_gb: int = FEATURES[tier][limit_key]
        # A 0 limit is reported separately so users see "not available"
        # rather than "exceeded".
        if limit_gb == 0:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"{service} is not available on the '{tier}' tier",
            )
        if not self._fits(limit_gb, current_bytes, additional_bytes):
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"{quota} quota exceeded for tier '{tier}'",
            )
# One shared instance for the whole app. TierManager keeps no per-instance
# state, so every caller can safely reuse this object.
tier_manager: TierManager = TierManager()