"""Tier manager: feature matrix and quota enforcement. ``TierManager`` is the single source of truth for what each billing tier allows. ``get_tier`` queries the ``subscriptions`` table for the live tier. Quota-enforcement helpers take ``tier`` directly — the caller already has it from ``current_user.tier`` (provided by ``get_current_user``). """ from __future__ import annotations from typing import Any from fastapi import HTTPException, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.schemas import BillingTier # Feature matrix per tier. -1 means unlimited; 0 means disabled. FEATURES: dict[str, dict[str, Any]] = { "free": { "agents": 3, "batch_active": 2, "batch_runs_per_day": 5, "providers": 1, "batch_builder": False, "sso": False, }, "pro": { "agents": -1, # unlimited "batch_active": 10, "batch_runs_per_day": 50, "providers": -1, "batch_builder": False, "sso": False, }, "power": { "agents": -1, "batch_active": -1, # unlimited "batch_runs_per_day": -1, # unlimited "providers": -1, "batch_builder": True, "sso": False, }, "team": { "agents": -1, "batch_active": -1, "batch_runs_per_day": -1, # unlimited "providers": -1, "batch_builder": True, "sso": True, }, } # Requests-per-minute limit per tier. RATE_LIMITS: dict[str, int] = { "free": 20, "pro": 60, "power": 120, "team": 200, } class TierManager: """Centralises tier feature-gating, rate-limit lookups, and quota checks.""" # ── Tier lookup ───────────────────────────────────────────────────── async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: """Return the current billing tier for ``user_id`` from the DB. Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod when no subscription row exists. """ from app.models import Subscription # noqa: PLC0415 from app.config.settings import settings # noqa: PLC0415 result = await db.execute( select(Subscription.tier).where(Subscription.user_id == user_id) ) tier: str | None = result.scalar_one_or_none() if tier is None or tier not in FEATURES: return "power" if settings.ENV == "dev" else "free" return tier # type: ignore[return-value] # ── Feature access ─────────────────────────────────────────────────── def check_feature(self, tier: BillingTier, feature: str) -> bool: """Return ``True`` if ``tier`` has ``feature`` enabled. For numeric features, any value > 0 or -1 (unlimited) counts as enabled. """ value = FEATURES.get(tier, FEATURES["free"]).get(feature) if value is None: return False if isinstance(value, bool): return value return value != 0 def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None: """Raise ``HTTP 403`` if ``tier`` does not have ``feature``.""" if not self.check_feature(tier, feature): detail = ( f"Feature '{feature}' requires {tier_name} tier or above." if tier_name else f"Feature '{feature}' is not available on your current tier." ) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) # ── Rate limiting ──────────────────────────────────────────────────── def get_rate_limit(self, tier: BillingTier) -> int: """Return the requests-per-minute limit for ``tier``.""" return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) # Module-level singleton shared across the app. tier_manager = TierManager()