step 11 complete: billing service and tier manager
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -376,13 +376,13 @@ adiuva-api/
|
||||
- `async get_earnings(developer_id, period) -> dict`
|
||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
||||
|
||||
### Step 11 — Billing & Tier management
|
||||
- [ ] `app/billing/stripe_service.py`:
|
||||
### Step 11 — Billing & Tier management ✅
|
||||
- [x] `app/billing/stripe_service.py`:
|
||||
- `create_checkout_session(user_id, tier) -> str`
|
||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
||||
- `get_subscription(user_id) -> dict | None`
|
||||
- `cancel_subscription(user_id) -> None`
|
||||
- [ ] `app/billing/tier_manager.py`:
|
||||
- [x] `app/billing/tier_manager.py`:
|
||||
- `TierManager`:
|
||||
- Feature matrix:
|
||||
```python
|
||||
@@ -433,6 +433,9 @@ adiuva-api/
|
||||
- `check_feature(user_id, feature) -> bool`
|
||||
- `get_rate_limit(tier) -> int`
|
||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
||||
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
||||
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
||||
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||
|
||||
### Step 12 — Database (auth/billing/marketplace only)
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import Any
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.schemas import BackupMetadata, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
@@ -27,32 +28,11 @@ _blob_store = BlobStore()
|
||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
||||
|
||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
||||
"free": 0,
|
||||
"pro": 5,
|
||||
"power": 25,
|
||||
"team": -1, # unlimited
|
||||
}
|
||||
|
||||
|
||||
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
|
||||
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail="Backup is not available on the free tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024**3
|
||||
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||
if used + size_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
|
||||
|
||||
|
||||
@router.put("")
|
||||
@@ -69,7 +49,7 @@ async def upload_backup(
|
||||
"""
|
||||
blob = await request.body()
|
||||
reject_if_tampered(blob, x_backup_checksum)
|
||||
_check_backup_quota(current_user.id, current_user.tier, len(blob))
|
||||
_check_backup_quota(current_user.id, len(blob))
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||
|
||||
@@ -1,44 +1,23 @@
|
||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||
|
||||
Subscription records are kept in-memory until Step 12 migrates them to
|
||||
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
||||
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
||||
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||
and delegates everything else to the service singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from fastapi import APIRouter, Depends, Header, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
||||
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
||||
|
||||
_TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _stripe_configured() -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
|
||||
def _stripe() -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
|
||||
# ── Request bodies ─────────────────────────────────────────────────────
|
||||
|
||||
@@ -57,34 +36,8 @@ async def create_checkout(
|
||||
|
||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||
"""
|
||||
if body.tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
if _stripe_configured():
|
||||
price_id = _TIER_PRICE_IDS.get(body.tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {body.tier}",
|
||||
)
|
||||
s = _stripe()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=(
|
||||
"https://app.adiuva.app/billing/success"
|
||||
"?session_id={CHECKOUT_SESSION_ID}"
|
||||
),
|
||||
cancel_url="https://app.adiuva.app/billing/cancel",
|
||||
metadata={"user_id": current_user.id, "tier": body.tier},
|
||||
)
|
||||
return {"checkout_url": session.url}
|
||||
|
||||
return {"checkout_url": "https://stripe.com/stub-checkout"}
|
||||
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||
return {"checkout_url": url}
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=dict)
|
||||
@@ -98,48 +51,7 @@ async def stripe_webhook(
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
|
||||
if not _stripe_configured():
|
||||
return {"ok": True}
|
||||
|
||||
try:
|
||||
s = _stripe()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
if user_id:
|
||||
_subscriptions[user_id] = {
|
||||
"tier": tier,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"status": "active",
|
||||
"current_period_end": None,
|
||||
}
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
|
||||
pass
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
||||
pass
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
# TODO(Step12): flag subscription as past_due, notify user
|
||||
pass
|
||||
|
||||
stripe_service.handle_webhook(payload, stripe_signature)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@@ -148,7 +60,7 @@ async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = _subscriptions.get(current_user.id)
|
||||
sub = stripe_service.get_subscription(current_user.id)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
@@ -159,26 +71,10 @@ async def get_subscription(
|
||||
return sub
|
||||
|
||||
|
||||
@router.delete("/subscription", response_model=dict)
|
||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
sub = _subscriptions.get(current_user.id)
|
||||
if sub is None or not sub.get("stripe_subscription_id"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if _stripe_configured():
|
||||
s = _stripe()
|
||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
||||
|
||||
_subscriptions[current_user.id] = {
|
||||
**sub,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
|
||||
stripe_service.cancel_subscription(current_user.id)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
@@ -25,14 +26,6 @@ _blob_store = BlobStore()
|
||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
||||
_records: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
||||
"free": 0,
|
||||
"pro": 5,
|
||||
"power": 25,
|
||||
"team": -1, # unlimited
|
||||
}
|
||||
|
||||
|
||||
# ── Local response schemas ─────────────────────────────────────────────
|
||||
|
||||
@@ -51,18 +44,10 @@ class _RecordMeta(BaseModel):
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
|
||||
def _check_quota(user_id: str, additional_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
||||
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024**3
|
||||
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||
if used + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||
)
|
||||
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
|
||||
|
||||
|
||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
||||
@@ -83,7 +68,7 @@ async def create_record(
|
||||
) -> _CreateResponse:
|
||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
_check_quota(current_user.id, current_user.tier, len(body.blob))
|
||||
_check_quota(current_user.id, len(body.blob))
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = int(time.time() * 1000)
|
||||
@@ -159,7 +144,7 @@ async def update_record(
|
||||
|
||||
delta = len(body.blob) - record["size_bytes"]
|
||||
if delta > 0:
|
||||
_check_quota(current_user.id, current_user.tier, delta)
|
||||
_check_quota(current_user.id, delta)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.billing.tier_manager import tier_manager
|
||||
|
||||
__all__ = ["stripe_service", "tier_manager"]
|
||||
|
||||
183
app/billing/stripe_service.py
Normal file
183
app/billing/stripe_service.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||
|
||||
Subscriptions are stored in-memory until Step 12 migrates them to the
|
||||
PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed
|
||||
when ``STRIPE_SECRET_KEY`` is not configured, enabling local development
|
||||
without live credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||
TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly",
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Wraps all Stripe interactions and owns the in-memory subscription store.
|
||||
|
||||
Step 12 will replace ``_subscriptions`` with real PostgreSQL queries.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# user_id → subscription record dict
|
||||
# Replaced by the ``subscriptions`` table in Step 12.
|
||||
self._subscriptions: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# ── Internal helpers ────────────────────────────────────────────────
|
||||
|
||||
def _configured(self) -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
def _client(self) -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: str,
|
||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||
) -> str:
|
||||
"""Create a Stripe checkout session and return the URL.
|
||||
|
||||
Returns a stub URL when Stripe is not configured.
|
||||
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||
"""
|
||||
if tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
price_id = TIER_PRICE_IDS.get(tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {tier}",
|
||||
)
|
||||
|
||||
if not self._configured():
|
||||
return "https://stripe.com/stub-checkout"
|
||||
|
||||
s = self._client()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={"user_id": user_id, "tier": tier},
|
||||
)
|
||||
return session.url
|
||||
|
||||
def handle_webhook(self, payload: bytes, sig_header: str) -> None:
|
||||
"""Process a Stripe webhook event.
|
||||
|
||||
Verifies the signature, then dispatches on event type.
|
||||
Raises ``HTTP 400`` on signature mismatch.
|
||||
No-ops when Stripe is not configured.
|
||||
"""
|
||||
if not self._configured():
|
||||
return
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
period_end = data.get("current_period_end")
|
||||
if user_id:
|
||||
self._subscriptions[user_id] = {
|
||||
"tier": tier,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"status": "active",
|
||||
"current_period_end": period_end,
|
||||
}
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, update tier
|
||||
sub_id = data.get("id")
|
||||
new_status = data.get("status")
|
||||
period_end = data.get("current_period_end")
|
||||
for record in self._subscriptions.values():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
record["status"] = new_status
|
||||
record["current_period_end"] = period_end
|
||||
break
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
||||
sub_id = data.get("id")
|
||||
for user_id, record in self._subscriptions.items():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
self._subscriptions[user_id] = {
|
||||
**record,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
break
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
# TODO(Step12): flag subscription as past_due, notify user
|
||||
sub_id = data.get("subscription")
|
||||
for record in self._subscriptions.values():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
record["status"] = "past_due"
|
||||
break
|
||||
|
||||
def get_subscription(self, user_id: str) -> dict[str, Any] | None:
|
||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||
return self._subscriptions.get(user_id)
|
||||
|
||||
def cancel_subscription(self, user_id: str) -> None:
|
||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||
|
||||
Raises ``HTTP 404`` when no active subscription exists.
|
||||
"""
|
||||
sub = self._subscriptions.get(user_id)
|
||||
if sub is None or not sub.get("stripe_subscription_id"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if self._configured():
|
||||
s = self._client()
|
||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
||||
|
||||
self._subscriptions[user_id] = {
|
||||
**sub,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
stripe_service = StripeService()
|
||||
207
app/billing/tier_manager.py
Normal file
207
app/billing/tier_manager.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` reads from the ``StripeService`` in-memory store until
|
||||
Step 12 replaces it with a live PostgreSQL lookup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||
FEATURES: dict[str, dict[str, Any]] = {
|
||||
"free": {
|
||||
"agents": 3,
|
||||
"batch_active": 2,
|
||||
"cloud_storage_gb": 0,
|
||||
"backup_gb": 0,
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"plugin_marketplace": False,
|
||||
"sso": False,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
"batch_active": 10,
|
||||
"cloud_storage_gb": 5,
|
||||
"backup_gb": 5,
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"plugin_marketplace": False,
|
||||
"sso": False,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
"batch_active": -1, # unlimited
|
||||
"cloud_storage_gb": 25,
|
||||
"backup_gb": 25,
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"plugin_marketplace": True,
|
||||
"sso": False,
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
"batch_active": -1,
|
||||
"cloud_storage_gb": -1, # unlimited
|
||||
"backup_gb": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"plugin_marketplace": True,
|
||||
"sso": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Requests-per-minute limit per tier.
|
||||
RATE_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks.
|
||||
|
||||
``get_tier`` consults the ``StripeService`` singleton. Step 12 will
|
||||
replace that with a PostgreSQL query so that the tier is always fresh.
|
||||
"""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
def get_tier(self, user_id: str) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id``.
|
||||
|
||||
Falls back to ``'free'`` when no subscription record exists.
|
||||
Step 12 will replace this with a live DB lookup.
|
||||
"""
|
||||
# Import here to avoid circular imports at module load time.
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
|
||||
sub = stripe_service.get_subscription(user_id)
|
||||
if sub is None:
|
||||
return "free"
|
||||
tier = sub.get("tier", "free")
|
||||
# Validate against known tiers; unknown values fall back to free.
|
||||
if tier not in FEATURES:
|
||||
return "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, user_id: str, feature: str) -> bool:
|
||||
"""Return ``True`` if ``user_id``'s current tier has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
tier = self.get_tier(user_id)
|
||||
value = FEATURES[tier].get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
# Numeric: -1 means unlimited (enabled), 0 means disabled.
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``user_id`` does not have ``feature``.
|
||||
|
||||
``tier_name`` is used in the error message to tell users which tier
|
||||
they need to upgrade to.
|
||||
"""
|
||||
if not self.check_feature(user_id, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
else f"Feature '{feature}' is not available on your current tier."
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
"""Return the requests-per-minute limit for ``tier``."""
|
||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||
|
||||
# ── Storage quota ────────────────────────────────────────────────────
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if ``user_id`` can store ``additional_bytes`` more data.
|
||||
|
||||
``current_bytes`` is the user's current storage usage (from the
|
||||
caller's record-keeping). Step 12 will remove these parameters and
|
||||
query the DB directly.
|
||||
|
||||
Returns ``False`` if the tier has no storage allocation at all
|
||||
(free tier), or if ``current_bytes + additional_bytes`` would exceed
|
||||
the tier's ``cloud_storage_gb`` limit.
|
||||
"""
|
||||
tier = self.get_tier(user_id)
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False # tier has no storage
|
||||
if limit_gb == -1:
|
||||
return True # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
def enforce_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota."""
|
||||
tier = self.get_tier(user_id)
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
if current_bytes + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def enforce_backup_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota."""
|
||||
tier = self.get_tier(user_id)
|
||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup is not available on the '{tier}' tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
if current_bytes + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
Reference in New Issue
Block a user