Add folder_max_files and folder_monthly_tokens to all four tier dicts in FEATURES, and add get_feature_value() helper to TierManager. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
150 lines
5.7 KiB
Python
150 lines
5.7 KiB
Python
"""Tier manager: feature matrix and quota enforcement.

``TierManager`` is the single source of truth for what each billing tier
allows. ``FEATURES`` maps each tier name to its boolean feature flags and
numeric limits (``-1`` means unlimited, ``0`` means disabled), and
``RATE_LIMITS`` maps each tier name to its requests-per-minute budget.

``TierManager.get_tier`` queries the ``subscriptions`` table for the live tier.
Quota-enforcement helpers take ``tier`` directly — the caller already has it
from ``current_user.tier`` (provided by ``get_current_user``).

Import the module-level ``tier_manager`` instance rather than constructing a
new ``TierManager``.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from fastapi import HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.schemas import BillingTier
|
|
|
|
# Sentinel stored in FEATURES for a numeric limit that has no cap.
_UNLIMITED = -1

# Feature matrix, keyed by tier name.
#
# Numeric entries are limits: ``_UNLIMITED`` (-1) means no cap, 0 means the
# feature is disabled, any positive number is the cap itself. Boolean entries
# are plain on/off flags. Every tier spells out the same keys in the same
# order so a missing entry is easy to spot when adding a new feature.
FEATURES: dict[str, dict[str, Any]] = {
    "free": {
        "agents": 3,
        "batch_active": 2,
        "batch_runs_per_day": 5,
        "providers": 1,
        "batch_builder": False,
        "sso": False,
        "real_embeddings": False,  # keyword fallback only
        "realtime_extraction": False,  # batch queue (Phase 2)
        "relational_memory": False,  # relational tier (Phase 3) — Pro+
        "proactive_mining": False,  # Power+ only (Phase 5)
        "folder_max_files": 200,
        "folder_monthly_tokens": 100_000,
    },
    "pro": {
        "agents": _UNLIMITED,
        "batch_active": 10,
        "batch_runs_per_day": 50,
        "providers": _UNLIMITED,
        "batch_builder": False,
        "sso": False,
        "real_embeddings": True,  # pgvector cosine search
        "realtime_extraction": True,  # fire-and-forget asyncio.create_task
        "relational_memory": True,  # person/project predicates
        "proactive_mining": False,  # Power+ only (Phase 5)
        "folder_max_files": 5000,
        "folder_monthly_tokens": 2_000_000,
    },
    "power": {
        "agents": _UNLIMITED,
        "batch_active": _UNLIMITED,
        "batch_runs_per_day": _UNLIMITED,
        "providers": _UNLIMITED,
        "batch_builder": True,
        "sso": False,
        "real_embeddings": True,
        "realtime_extraction": True,
        "relational_memory": True,  # all predicates incl. custom
        "proactive_mining": True,  # scheduled pattern mining (Phase 5)
        "folder_max_files": _UNLIMITED,
        "folder_monthly_tokens": _UNLIMITED,
    },
    "team": {
        "agents": _UNLIMITED,
        "batch_active": _UNLIMITED,
        "batch_runs_per_day": _UNLIMITED,
        "providers": _UNLIMITED,
        "batch_builder": True,
        "sso": True,
        "real_embeddings": True,
        "realtime_extraction": True,
        "relational_memory": True,  # all predicates incl. custom
        "proactive_mining": True,  # scheduled pattern mining (Phase 5)
        "folder_max_files": _UNLIMITED,
        "folder_monthly_tokens": _UNLIMITED,
    },
}
|
|
|
|
# Per-tier API rate limit, in requests per minute. Keyed by the same tier
# names as FEATURES ("free", "pro", "power", "team").
RATE_LIMITS: dict[str, int] = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
|
|
|
|
|
class TierManager:
    """Centralises tier feature-gating, rate-limit lookups, and quota checks.

    Feature lookups read the module-level ``FEATURES`` table and rate-limit
    lookups read ``RATE_LIMITS``; in both, a tier name that is not a key of the
    table is treated as ``"free"``.
    """

    # ── Tier lookup ─────────────────────────────────────────────────────

    async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
        """Return the current billing tier for ``user_id`` from the DB.

        Reads ``Subscription.tier`` for the user. When no subscription row
        exists, or the stored tier is not a key of ``FEATURES``, falls back to
        ``'power'`` when ``settings.ENV == "dev"`` (unlimited) and to
        ``'free'`` otherwise.

        Args:
            user_id: Value matched against ``Subscription.user_id``.
            db: Async session used to run the lookup query.

        Returns:
            The billing tier name.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: via ``scalar_one_or_none``
                if more than one subscription row matches. Presumably
                ``Subscription.user_id`` is unique — TODO confirm.
        """
        from app.models import Subscription  # noqa: PLC0415
        from app.config.settings import settings  # noqa: PLC0415

        result = await db.execute(
            select(Subscription.tier).where(Subscription.user_id == user_id)
        )
        tier: str | None = result.scalar_one_or_none()
        if tier is None or tier not in FEATURES:
            return "power" if settings.ENV == "dev" else "free"
        return tier  # type: ignore[return-value]

    # ── Feature access ───────────────────────────────────────────────────

    def _lookup(self, tier: BillingTier, feature: str) -> Any:
        """Return the raw ``FEATURES`` entry for ``feature`` on ``tier``.

        An unknown ``tier`` falls back to the ``"free"`` row; an unknown
        ``feature`` yields ``None``.
        """
        return FEATURES.get(tier, FEATURES["free"]).get(feature)

    def check_feature(self, tier: BillingTier, feature: str) -> bool:
        """Return ``True`` if ``tier`` has ``feature`` enabled.

        Boolean entries are returned as-is. Any other configured value counts
        as enabled unless it equals ``0``, so positive limits and ``-1``
        (unlimited) are both enabled. Features missing from the table are
        treated as disabled.
        """
        value = self._lookup(tier, feature)
        if value is None:
            return False
        if isinstance(value, bool):
            return value
        return value != 0

    def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
        """Raise ``HTTP 403`` if ``tier`` does not have ``feature``.

        Args:
            tier: The caller's billing tier.
            feature: Key in ``FEATURES`` to check via ``check_feature``.
            tier_name: Optional minimum tier to name in the error message;
                when empty a generic message is used.

        Raises:
            HTTPException: 403 Forbidden when the feature is not enabled.
        """
        if not self.check_feature(tier, feature):
            detail = (
                f"Feature '{feature}' requires {tier_name} tier or above."
                if tier_name
                else f"Feature '{feature}' is not available on your current tier."
            )
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)

    def get_feature_value(self, tier: BillingTier, feature: str) -> int:
        """Return the integer value of ``feature`` for ``tier``.

        ``-1`` means unlimited and ``0`` means disabled. Boolean flags are
        converted to a plain ``int`` (``1``/``0``) so the return type is always
        ``int``; prefer ``check_feature`` for flags. Missing or non-integer
        features return ``0``.
        """
        value = self._lookup(tier, feature)
        # ``bool`` is a subclass of ``int``, so ``isinstance`` alone would let
        # ``True``/``False`` escape; ``int()`` normalises them (and is a no-op
        # for real ints).
        return int(value) if isinstance(value, int) else 0

    # ── Rate limiting ────────────────────────────────────────────────────

    def get_rate_limit(self, tier: BillingTier) -> int:
        """Return the requests-per-minute limit for ``tier``.

        Unknown tiers get the ``"free"`` limit.
        """
        return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
|
|
|
|
|
# Shared instance for the whole app. TierManager defines no ``__init__`` and
# keeps no per-instance state, so importers reuse this object instead of
# building their own.
tier_manager: TierManager = TierManager()
|