This commit is contained in:
2026-03-03 12:39:32 +01:00
parent 9787befd4a
commit 5d485b3665
12 changed files with 999 additions and 165 deletions

View File

@@ -1,17 +1,19 @@
"""Stripe service: checkout sessions, webhook handling, subscription management.
Subscriptions are stored in-memory until Step 12 migrates them to the
PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed
when ``STRIPE_SECRET_KEY`` is not configured, enabling local development
without live credentials.
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
configured, enabling local development without live credentials.
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
@@ -24,15 +26,7 @@ TIER_PRICE_IDS: dict[str, str] = {
class StripeService:
"""Wraps all Stripe interactions and owns the in-memory subscription store.
Step 12 will replace ``_subscriptions`` with real PostgreSQL queries.
"""
def __init__(self) -> None:
# user_id → subscription record dict
# Replaced by the ``subscriptions`` table in Step 12.
self._subscriptions: dict[str, dict[str, Any]] = {}
"""Wraps all Stripe interactions and owns subscription persistence."""
# ── Internal helpers ────────────────────────────────────────────────
@@ -84,7 +78,12 @@ class StripeService:
)
return session.url
def handle_webhook(self, payload: bytes, sig_header: str) -> None:
async def handle_webhook(
self,
payload: bytes,
sig_header: str,
db: AsyncSession,
) -> None:
"""Process a Stripe webhook event.
Verifies the signature, then dispatches on event type.
@@ -112,57 +111,82 @@ class StripeService:
user_id = data.get("metadata", {}).get("user_id")
tier = data.get("metadata", {}).get("tier", "free")
sub_id = data.get("subscription")
period_end = data.get("current_period_end")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if user_id:
self._subscriptions[user_id] = {
"tier": tier,
"stripe_subscription_id": sub_id,
"status": "active",
"current_period_end": period_end,
}
await self._upsert_subscription(
db, user_id, sub_id, tier, "active", period_end
)
elif event_type == "customer.subscription.updated":
# TODO(Step12): look up user_id from stripe_customer_id in DB, update tier
sub_id = data.get("id")
new_status = data.get("status")
period_end = data.get("current_period_end")
for record in self._subscriptions.values():
if record.get("stripe_subscription_id") == sub_id:
record["status"] = new_status
record["current_period_end"] = period_end
break
new_status = data.get("status", "active")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status=new_status, current_period_end=period_end
)
elif event_type == "customer.subscription.deleted":
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
sub_id = data.get("id")
for user_id, record in self._subscriptions.items():
if record.get("stripe_subscription_id") == sub_id:
self._subscriptions[user_id] = {
**record,
"tier": "free",
"status": "canceled",
}
break
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, tier="free", status="canceled"
)
elif event_type == "invoice.payment_failed":
# TODO(Step12): flag subscription as past_due, notify user
sub_id = data.get("subscription")
for record in self._subscriptions.values():
if record.get("stripe_subscription_id") == sub_id:
record["status"] = "past_due"
break
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status="past_due"
)
def get_subscription(self, user_id: str) -> dict[str, Any] | None:
await db.commit()
async def get_subscription(
self, user_id: str, db: AsyncSession
) -> dict[str, Any] | None:
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
return self._subscriptions.get(user_id)
from app.models import Subscription # noqa: PLC0415
def cancel_subscription(self, user_id: str) -> None:
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
return None
return {
"tier": sub.tier,
"stripe_subscription_id": sub.stripe_subscription_id,
"status": sub.status,
"current_period_end": (
int(sub.current_period_end.timestamp() * 1000)
if sub.current_period_end
else None
),
}
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
"""Cancel the user's Stripe subscription and downgrade them to free.
Raises ``HTTP 404`` when no active subscription exists.
"""
sub = self._subscriptions.get(user_id)
if sub is None or not sub.get("stripe_subscription_id"):
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None or not sub.stripe_subscription_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active subscription found",
@@ -170,13 +194,62 @@ class StripeService:
if self._configured():
s = self._client()
s.Subscription.cancel(sub["stripe_subscription_id"])
s.Subscription.cancel(sub.stripe_subscription_id)
self._subscriptions[user_id] = {
**sub,
"tier": "free",
"status": "canceled",
}
sub.tier = "free"
sub.status = "canceled"
await db.commit()
# ── Private DB helpers ───────────────────────────────────────────────
async def _upsert_subscription(
self,
db: AsyncSession,
user_id: str,
stripe_subscription_id: str | None,
tier: str,
sub_status: str,
current_period_end: datetime | None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
sub = Subscription(user_id=user_id)
db.add(sub)
sub.stripe_subscription_id = stripe_subscription_id
sub.tier = tier
sub.status = sub_status
sub.current_period_end = current_period_end
async def _update_subscription_by_stripe_id(
self,
db: AsyncSession,
stripe_subscription_id: str,
*,
tier: str | None = None,
status: str | None = None,
current_period_end: datetime | None = None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(
Subscription.stripe_subscription_id == stripe_subscription_id
)
)
sub = result.scalar_one_or_none()
if sub is None:
return
if tier is not None:
sub.tier = tier
if status is not None:
sub.status = status
if current_period_end is not None:
sub.current_period_end = current_period_end
# Module-level singleton shared across the app.