step 11 complete: billing service and tier manager

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-02 22:41:35 +01:00
parent 8f7bc25611
commit 9787befd4a
7 changed files with 422 additions and 164 deletions

View File

@@ -16,6 +16,7 @@ from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -27,32 +28,11 @@ _blob_store = BlobStore()
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
"free": 0,
"pro": 5,
"power": 25,
"team": -1, # unlimited
}
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail="Backup is not available on the free tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024**3
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
if used + size_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
@router.put("")
@@ -69,7 +49,7 @@ async def upload_backup(
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
_check_backup_quota(current_user.id, current_user.tier, len(blob))
_check_backup_quota(current_user.id, len(blob))
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum

View File

@@ -1,44 +1,23 @@
"""Billing routes: Stripe checkout, webhook, subscription management.
Subscription records are kept in-memory until Step 12 migrates them to
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
Business logic lives in ``app.billing.stripe_service.StripeService``.
The route layer handles HTTP concerns (request parsing, response shaping)
and delegates everything else to the service singleton.
"""
from __future__ import annotations
from typing import Any
import stripe as stripe_lib
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi import APIRouter, Depends, Header, Request, status
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.config.settings import settings
from app.billing.stripe_service import stripe_service
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
_TIER_PRICE_IDS: dict[str, str] = {
"pro": "price_pro_monthly", # replace with real Stripe price IDs
"power": "price_power_monthly",
"team": "price_team_monthly",
}
# ── Helpers ────────────────────────────────────────────────────────────
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Request bodies ─────────────────────────────────────────────────────
@@ -57,34 +36,8 @@ async def create_checkout(
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
"""
if body.tier == "free":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create a checkout session for the free tier",
)
if _stripe_configured():
price_id = _TIER_PRICE_IDS.get(body.tier)
if not price_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown tier: {body.tier}",
)
s = _stripe()
session = s.checkout.Session.create(
payment_method_types=["card"],
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
success_url=(
"https://app.adiuva.app/billing/success"
"?session_id={CHECKOUT_SESSION_ID}"
),
cancel_url="https://app.adiuva.app/billing/cancel",
metadata={"user_id": current_user.id, "tier": body.tier},
)
return {"checkout_url": session.url}
return {"checkout_url": "https://stripe.com/stub-checkout"}
url = stripe_service.create_checkout_session(current_user.id, body.tier)
return {"checkout_url": url}
@router.post("/webhook", response_model=dict)
@@ -98,48 +51,7 @@ async def stripe_webhook(
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
if not _stripe_configured():
return {"ok": True}
try:
s = _stripe()
event = s.Webhook.construct_event(
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
)
except stripe_lib.error.SignatureVerificationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid Stripe signature",
)
event_type: str = event["type"]
data: dict[str, Any] = event["data"]["object"]
if event_type == "checkout.session.completed":
user_id = data.get("metadata", {}).get("user_id")
tier = data.get("metadata", {}).get("tier", "free")
sub_id = data.get("subscription")
if user_id:
_subscriptions[user_id] = {
"tier": tier,
"stripe_subscription_id": sub_id,
"status": "active",
"current_period_end": None,
}
elif event_type == "customer.subscription.updated":
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
pass
elif event_type == "customer.subscription.deleted":
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
pass
elif event_type == "invoice.payment_failed":
# TODO(Step12): flag subscription as past_due, notify user
pass
stripe_service.handle_webhook(payload, stripe_signature)
return {"ok": True}
@@ -148,7 +60,7 @@ async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = _subscriptions.get(current_user.id)
sub = stripe_service.get_subscription(current_user.id)
if sub is None:
return {
"tier": current_user.tier,
@@ -159,26 +71,10 @@ async def get_subscription(
return sub
@router.delete("/subscription", response_model=dict)
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Cancel the active subscription."""
sub = _subscriptions.get(current_user.id)
if sub is None or not sub.get("stripe_subscription_id"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active subscription found",
)
if _stripe_configured():
s = _stripe()
s.Subscription.cancel(sub["stripe_subscription_id"])
_subscriptions[current_user.id] = {
**sub,
"tier": "free",
"status": "canceled",
}
stripe_service.cancel_subscription(current_user.id)
return {"ok": True}

View File

@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -25,14 +26,6 @@ _blob_store = BlobStore()
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
_records: dict[str, dict[str, Any]] = {}
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
"free": 0,
"pro": 5,
"power": 25,
"team": -1, # unlimited
}
# ── Local response schemas ─────────────────────────────────────────────
@@ -51,18 +44,10 @@ class _RecordMeta(BaseModel):
# ── Helpers ────────────────────────────────────────────────────────────
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
def _check_quota(user_id: str, additional_bytes: int) -> None:
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024**3
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
if used + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
@@ -83,7 +68,7 @@ async def create_record(
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
_check_quota(current_user.id, current_user.tier, len(body.blob))
_check_quota(current_user.id, len(body.blob))
record_id = str(uuid.uuid4())
now = int(time.time() * 1000)
@@ -159,7 +144,7 @@ async def update_record(
delta = len(body.blob) - record["size_bytes"]
if delta > 0:
_check_quota(current_user.id, current_user.tier, delta)
_check_quota(current_user.id, delta)
s3_key = await _blob_store.upload(
current_user.id, record["table"], record_id, body.blob, body.checksum