step 12
This commit is contained in:
@@ -1,17 +1,19 @@
|
||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||
|
||||
Subscriptions are stored in-memory until Step 12 migrates them to the
|
||||
PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed
|
||||
when ``STRIPE_SECRET_KEY`` is not configured, enabling local development
|
||||
without live credentials.
|
||||
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||
configured, enabling local development without live credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
@@ -24,15 +26,7 @@ TIER_PRICE_IDS: dict[str, str] = {
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Wraps all Stripe interactions and owns the in-memory subscription store.
|
||||
|
||||
Step 12 will replace ``_subscriptions`` with real PostgreSQL queries.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# user_id → subscription record dict
|
||||
# Replaced by the ``subscriptions`` table in Step 12.
|
||||
self._subscriptions: dict[str, dict[str, Any]] = {}
|
||||
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||
|
||||
# ── Internal helpers ────────────────────────────────────────────────
|
||||
|
||||
@@ -84,7 +78,12 @@ class StripeService:
|
||||
)
|
||||
return session.url
|
||||
|
||||
def handle_webhook(self, payload: bytes, sig_header: str) -> None:
|
||||
async def handle_webhook(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Process a Stripe webhook event.
|
||||
|
||||
Verifies the signature, then dispatches on event type.
|
||||
@@ -112,57 +111,82 @@ class StripeService:
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
period_end = data.get("current_period_end")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if user_id:
|
||||
self._subscriptions[user_id] = {
|
||||
"tier": tier,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"status": "active",
|
||||
"current_period_end": period_end,
|
||||
}
|
||||
await self._upsert_subscription(
|
||||
db, user_id, sub_id, tier, "active", period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, update tier
|
||||
sub_id = data.get("id")
|
||||
new_status = data.get("status")
|
||||
period_end = data.get("current_period_end")
|
||||
for record in self._subscriptions.values():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
record["status"] = new_status
|
||||
record["current_period_end"] = period_end
|
||||
break
|
||||
new_status = data.get("status", "active")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status=new_status, current_period_end=period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
||||
sub_id = data.get("id")
|
||||
for user_id, record in self._subscriptions.items():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
self._subscriptions[user_id] = {
|
||||
**record,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
break
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, tier="free", status="canceled"
|
||||
)
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
# TODO(Step12): flag subscription as past_due, notify user
|
||||
sub_id = data.get("subscription")
|
||||
for record in self._subscriptions.values():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
record["status"] = "past_due"
|
||||
break
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status="past_due"
|
||||
)
|
||||
|
||||
def get_subscription(self, user_id: str) -> dict[str, Any] | None:
|
||||
await db.commit()
|
||||
|
||||
async def get_subscription(
|
||||
self, user_id: str, db: AsyncSession
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||
return self._subscriptions.get(user_id)
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
def cancel_subscription(self, user_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return None
|
||||
return {
|
||||
"tier": sub.tier,
|
||||
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||
"status": sub.status,
|
||||
"current_period_end": (
|
||||
int(sub.current_period_end.timestamp() * 1000)
|
||||
if sub.current_period_end
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||
|
||||
Raises ``HTTP 404`` when no active subscription exists.
|
||||
"""
|
||||
sub = self._subscriptions.get(user_id)
|
||||
if sub is None or not sub.get("stripe_subscription_id"):
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None or not sub.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
@@ -170,13 +194,62 @@ class StripeService:
|
||||
|
||||
if self._configured():
|
||||
s = self._client()
|
||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
||||
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||
|
||||
self._subscriptions[user_id] = {
|
||||
**sub,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
sub.tier = "free"
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
stripe_subscription_id: str | None,
|
||||
tier: str,
|
||||
sub_status: str,
|
||||
current_period_end: datetime | None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
sub = Subscription(user_id=user_id)
|
||||
db.add(sub)
|
||||
sub.stripe_subscription_id = stripe_subscription_id
|
||||
sub.tier = tier
|
||||
sub.status = sub_status
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
async def _update_subscription_by_stripe_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
stripe_subscription_id: str,
|
||||
*,
|
||||
tier: str | None = None,
|
||||
status: str | None = None,
|
||||
current_period_end: datetime | None = None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||
)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
if tier is not None:
|
||||
sub.tier = tier
|
||||
if status is not None:
|
||||
sub.status = status
|
||||
if current_period_end is not None:
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` reads from the ``StripeService`` in-memory store until
|
||||
Step 12 replaces it with a live PostgreSQL lookup.
|
||||
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -10,6 +11,8 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
@@ -67,55 +70,42 @@ RATE_LIMITS: dict[str, int] = {
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks.
|
||||
|
||||
``get_tier`` consults the ``StripeService`` singleton. Step 12 will
|
||||
replace that with a PostgreSQL query so that the tier is always fresh.
|
||||
"""
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
def get_tier(self, user_id: str) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id``.
|
||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id`` from the DB.
|
||||
|
||||
Falls back to ``'free'`` when no subscription record exists.
|
||||
Step 12 will replace this with a live DB lookup.
|
||||
Falls back to ``'free'`` when no subscription row exists.
|
||||
"""
|
||||
# Import here to avoid circular imports at module load time.
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
sub = stripe_service.get_subscription(user_id)
|
||||
if sub is None:
|
||||
return "free"
|
||||
tier = sub.get("tier", "free")
|
||||
# Validate against known tiers; unknown values fall back to free.
|
||||
if tier not in FEATURES:
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str | None = result.scalar_one_or_none()
|
||||
if tier is None or tier not in FEATURES:
|
||||
return "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, user_id: str, feature: str) -> bool:
|
||||
"""Return ``True`` if ``user_id``'s current tier has ``feature`` enabled.
|
||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
tier = self.get_tier(user_id)
|
||||
value = FEATURES[tier].get(feature)
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
# Numeric: -1 means unlimited (enabled), 0 means disabled.
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``user_id`` does not have ``feature``.
|
||||
|
||||
``tier_name`` is used in the error message to tell users which tier
|
||||
they need to upgrade to.
|
||||
"""
|
||||
if not self.check_feature(user_id, feature):
|
||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||
if not self.check_feature(tier, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
@@ -131,39 +121,17 @@ class TierManager:
|
||||
|
||||
# ── Storage quota ────────────────────────────────────────────────────
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if ``user_id`` can store ``additional_bytes`` more data.
|
||||
|
||||
``current_bytes`` is the user's current storage usage (from the
|
||||
caller's record-keeping). Step 12 will remove these parameters and
|
||||
query the DB directly.
|
||||
|
||||
Returns ``False`` if the tier has no storage allocation at all
|
||||
(free tier), or if ``current_bytes + additional_bytes`` would exceed
|
||||
the tier's ``cloud_storage_gb`` limit.
|
||||
"""
|
||||
tier = self.get_tier(user_id)
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False # tier has no storage
|
||||
if limit_gb == -1:
|
||||
return True # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
def enforce_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota."""
|
||||
tier = self.get_tier(user_id)
|
||||
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||
|
||||
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||
"""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
@@ -181,12 +149,11 @@ class TierManager:
|
||||
|
||||
def enforce_backup_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota."""
|
||||
tier = self.get_tier(user_id)
|
||||
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
@@ -202,6 +169,21 @@ class TierManager:
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False
|
||||
if limit_gb == -1:
|
||||
return True
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
|
||||
Reference in New Issue
Block a user