190 lines
6.7 KiB
Python
190 lines
6.7 KiB
Python
"""Tier manager: feature matrix and quota enforcement.

``TierManager`` is the single source of truth for what each billing tier
allows (see the ``FEATURES`` and ``RATE_LIMITS`` tables below).
``get_tier`` queries the ``subscriptions`` table for a user's live tier.

The quota-enforcement helpers take ``tier`` directly rather than looking it
up, because callers already have it from ``current_user.tier`` (provided by
``get_current_user``).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.schemas import BillingTier
|
|
|
|
# Per-tier feature matrix, listed from the least to the most capable tier.
# Numeric entries are counts or storage sizes in GB, where -1 stands for
# "unlimited" and 0 for "disabled"; boolean entries are plain on/off switches.
# Every tier row carries the same keys in the same order.
FEATURES: dict[str, dict[str, Any]] = dict(
    free=dict(
        agents=3,
        batch_active=2,
        cloud_storage_gb=0,
        backup_gb=0,
        providers=1,
        batch_builder=False,
        plugin_marketplace=False,
        sso=False,
    ),
    pro=dict(
        agents=-1,  # unlimited
        batch_active=10,
        cloud_storage_gb=5,
        backup_gb=5,
        providers=-1,
        batch_builder=False,
        plugin_marketplace=False,
        sso=False,
    ),
    power=dict(
        agents=-1,
        batch_active=-1,  # unlimited
        cloud_storage_gb=25,
        backup_gb=25,
        providers=-1,
        batch_builder=True,
        plugin_marketplace=True,
        sso=False,
    ),
    team=dict(
        agents=-1,
        batch_active=-1,
        cloud_storage_gb=-1,  # unlimited
        backup_gb=-1,  # unlimited
        providers=-1,
        batch_builder=True,
        plugin_marketplace=True,
        sso=True,
    ),
)
|
|
|
|
# Maximum number of requests per minute allowed for each billing tier.
RATE_LIMITS: dict[str, int] = dict(free=20, pro=60, power=120, team=200)
|
|
|
|
|
|
class TierManager:
    """Centralises tier feature-gating, rate-limit lookups, and quota checks.

    All limits come from the module-level ``FEATURES`` and ``RATE_LIMITS``
    tables. Every lookup treats a tier that is missing from those tables as
    ``'free'`` (matching what ``get_tier`` returns for unknown tiers), so a
    stale or unexpected ``current_user.tier`` degrades to free-tier limits
    instead of surfacing as a ``KeyError`` / HTTP 500.
    """

    # Storage quotas in FEATURES are named "*_gb" but converted with a binary
    # multiplier, i.e. 1 "GB" here is 1 GiB == 1024**3 bytes.
    _BYTES_PER_GB = 1024 ** 3

    # ── Tier lookup ─────────────────────────────────────────────────────

    async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
        """Return the current billing tier for ``user_id`` from the DB.

        Falls back to ``'free'`` when no subscription row exists or when the
        stored tier is not a key of ``FEATURES``.

        ``scalar_one_or_none`` raises ``sqlalchemy.exc.MultipleResultsFound``
        if more than one subscription row matches ``user_id``.
        """
        # Function-scope import; presumably avoids an import cycle with
        # app.models — confirm before hoisting it to module level.
        from app.models import Subscription  # noqa: PLC0415

        result = await db.execute(
            select(Subscription.tier).where(Subscription.user_id == user_id)
        )
        tier: str | None = result.scalar_one_or_none()
        if tier is None or tier not in FEATURES:
            return "free"
        return tier  # type: ignore[return-value]

    # ── Feature access ───────────────────────────────────────────────────

    @staticmethod
    def _features_for(tier: str) -> dict[str, Any]:
        """Return the ``FEATURES`` row for ``tier`` (the ``'free'`` row if unknown)."""
        return FEATURES.get(tier, FEATURES["free"])

    def check_feature(self, tier: BillingTier, feature: str) -> bool:
        """Return ``True`` if ``tier`` has ``feature`` enabled.

        Boolean features are returned as-is. Numeric features count as enabled
        for any non-zero value (a positive limit, or ``-1`` for unlimited).
        Unknown feature names return ``False``; unknown tiers use the
        ``'free'`` row.
        """
        value = self._features_for(tier).get(feature)
        if value is None:
            return False
        # bool must be checked before the numeric test: bool is a subclass of
        # int, and the stored flag should be returned unchanged.
        if isinstance(value, bool):
            return value
        return value != 0

    def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
        """Raise ``HTTP 403`` if ``tier`` does not have ``feature``.

        When ``tier_name`` is given, the error message names it as the tier
        the feature requires ("... requires <tier_name> tier or above");
        otherwise a generic "not available" message is used.
        """
        if not self.check_feature(tier, feature):
            detail = (
                f"Feature '{feature}' requires {tier_name} tier or above."
                if tier_name
                else f"Feature '{feature}' is not available on your current tier."
            )
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)

    # ── Rate limiting ────────────────────────────────────────────────────

    def get_rate_limit(self, tier: BillingTier) -> int:
        """Return the requests-per-minute limit for ``tier`` (free limit if unknown)."""
        return RATE_LIMITS.get(tier, RATE_LIMITS["free"])

    # ── Storage quota ────────────────────────────────────────────────────

    def _quota_limit_bytes(self, tier: str, key: str) -> int | None:
        """Return the byte quota stored under ``key`` (e.g. ``"backup_gb"``).

        Returns ``None`` when the quota is unlimited (``-1``) and ``0`` when
        the feature is disabled on that tier. Unknown tiers use the
        ``'free'`` row.
        """
        limit_gb: int = self._features_for(tier)[key]
        if limit_gb == -1:
            return None
        return limit_gb * self._BYTES_PER_GB

    def enforce_quota(
        self,
        tier: BillingTier,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> None:
        """Raise ``HTTP 402`` if the user would exceed their cloud storage quota.

        ``tier`` is the caller's current tier (from ``current_user.tier``).
        ``current_bytes`` is the total bytes already stored (queried by caller)
        and ``additional_bytes`` the bytes about to be added.

        Raises ``HTTPException`` (402) when cloud storage is disabled on
        ``tier``, or when ``current_bytes + additional_bytes`` exceeds the quota.
        """
        limit_bytes = self._quota_limit_bytes(tier, "cloud_storage_gb")
        if limit_bytes == 0:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"Cloud storage is not available on the '{tier}' tier",
            )
        if limit_bytes is not None and current_bytes + additional_bytes > limit_bytes:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"Storage quota exceeded for tier '{tier}'",
            )

    def enforce_backup_quota(
        self,
        tier: BillingTier,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> None:
        """Raise ``HTTP 402`` if the user would exceed their backup quota.

        Same contract as ``enforce_quota`` but checked against ``backup_gb``:
        raises ``HTTPException`` (402) when backups are disabled on ``tier``,
        or when ``current_bytes + additional_bytes`` exceeds the backup quota.
        """
        limit_bytes = self._quota_limit_bytes(tier, "backup_gb")
        if limit_bytes == 0:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"Backup is not available on the '{tier}' tier",
            )
        if limit_bytes is not None and current_bytes + additional_bytes > limit_bytes:
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail=f"Backup quota exceeded for tier '{tier}'",
            )

    def check_quota(
        self,
        tier: BillingTier,
        current_bytes: int = 0,
        additional_bytes: int = 0,
    ) -> bool:
        """Return ``True`` if the user can store ``additional_bytes`` more data.

        Non-raising counterpart of ``enforce_quota`` for the cloud storage
        quota: ``False`` when cloud storage is disabled on ``tier`` or the
        total would exceed the quota, ``True`` when the quota is unlimited.
        """
        limit_bytes = self._quota_limit_bytes(tier, "cloud_storage_gb")
        if limit_bytes == 0:
            return False
        if limit_bytes is None:
            return True
        return current_bytes + additional_bytes <= limit_bytes
|
|
|
|
|
|
# Single shared instance used across the app. TierManager keeps no
# per-instance state, so one object can serve every caller.
tier_manager: TierManager = TierManager()
|