- stripe_service: checkout sessions, webhook handling, subscription CRUD
- tier_manager: feature matrix (4 tiers), quota enforcement, rate limits
- routes: checkout, webhook (no auth), subscription, tier query, features
- Traefik header auth (X-User-Id) replaces get_current_user dependency
- /tier/{user_id} endpoint for internal service-to-service lookups
- /features and /features/{tier} for feature matrix queries
- Dockerfile: single worker, 30s timeout (lightweight service)
241 lines · 8.2 KiB · Python
"""Stripe service: checkout sessions, webhook handling, subscription management.

Adapted for the Billing microservice — uses shared.config and shared.models;
callers supply the database session.

All Stripe calls are gracefully stubbed when STRIPE_SECRET_KEY is not
configured, enabling local development without live credentials.
"""
|
|
|
|
from __future__ import annotations

import logging
import os
from datetime import datetime, timezone
from typing import Any

import stripe as stripe_lib
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from shared.config import settings
from shared.models import Subscription
|
|
|
|
# Stripe price IDs per paid tier. The values below are placeholders; set
# STRIPE_PRICE_PRO / STRIPE_PRICE_POWER / STRIPE_PRICE_TEAM in the process
# environment (e.g. from the production .env) to use the real price IDs.
# "free" has no entry: create_checkout_session rejects it before lookup.
TIER_PRICE_IDS: dict[str, str] = {
    "pro": os.environ.get("STRIPE_PRICE_PRO", "price_pro_monthly"),
    "power": os.environ.get("STRIPE_PRICE_POWER", "price_power_monthly"),
    "team": os.environ.get("STRIPE_PRICE_TEAM", "price_team_monthly"),
}
|
|
|
|
|
|
class StripeService:
    """Wrap all Stripe interactions and own subscription persistence.

    Every Stripe call is skipped when ``settings.STRIPE_SECRET_KEY`` is empty:
    checkout returns a stub URL, webhooks are ignored, and cancellation only
    updates the local record.
    """

    # ── Internal helpers ────────────────────────────────────────────────

    def _configured(self) -> bool:
        """Return True when a Stripe secret key is configured."""
        return bool(settings.STRIPE_SECRET_KEY)

    def _client(self) -> Any:
        """Return the ``stripe`` module with the configured API key applied.

        Note: this assigns the module-global ``stripe.api_key``.
        """
        stripe_lib.api_key = settings.STRIPE_SECRET_KEY
        return stripe_lib

    @staticmethod
    def _period_end_from_subscription(subscription: Any) -> datetime | None:
        """Return a Stripe Subscription's ``current_period_end`` as aware UTC.

        The timestamp is read from the subscription itself and, failing that,
        from its first subscription item (newer Stripe API versions report the
        billing period per item). Returns None when neither is present.
        """
        ts = subscription.get("current_period_end")
        if not ts:
            items = (subscription.get("items") or {}).get("data") or []
            if items:
                ts = items[0].get("current_period_end")
        return datetime.fromtimestamp(ts, tz=timezone.utc) if ts else None

    @staticmethod
    def _tier_from_subscription(subscription: Any) -> str | None:
        """Map a Subscription's first item price ID back to a tier name.

        Returns None when there are no items or the price ID is not one of
        ``TIER_PRICE_IDS``, so callers can leave the stored tier untouched.
        """
        items = (subscription.get("items") or {}).get("data") or []
        if not items:
            return None
        price_id = (items[0].get("price") or {}).get("id")
        return next(
            (tier for tier, tier_price in TIER_PRICE_IDS.items() if tier_price == price_id),
            None,
        )

    def _fetch_period_end(self, client: Any, sub_id: str | None) -> datetime | None:
        """Best-effort lookup of a subscription's current period end.

        A Checkout Session object does not carry the billing period, so the
        subscription it created is retrieved from Stripe. Stripe errors are
        logged and yield None so they never block activating the tier.
        """
        if not sub_id:
            return None
        try:
            subscription = client.Subscription.retrieve(sub_id)
        except stripe_lib.error.StripeError:
            logging.getLogger(__name__).warning(
                "Could not retrieve Stripe subscription %s", sub_id, exc_info=True
            )
            return None
        return self._period_end_from_subscription(subscription)

    # ── Public API ──────────────────────────────────────────────────────

    def create_checkout_session(
        self,
        user_id: str,
        tier: str,
        success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
        cancel_url: str = "https://app.adiuva.app/billing/cancel",
    ) -> str:
        """Create a subscription Checkout Session and return its URL.

        ``user_id`` and ``tier`` are stored in the session metadata, where
        ``handle_webhook`` reads them back on ``checkout.session.completed``.
        When Stripe is not configured a fixed stub URL is returned.

        Raises:
            HTTPException: 400 for the free tier or an unknown tier.
        """
        if tier == "free":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Cannot create a checkout session for the free tier",
            )

        price_id = TIER_PRICE_IDS.get(tier)
        if not price_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unknown tier: {tier}",
            )

        if not self._configured():
            return "https://stripe.com/stub-checkout"

        s = self._client()
        session = s.checkout.Session.create(
            payment_method_types=["card"],
            mode="subscription",
            line_items=[{"price": price_id, "quantity": 1}],
            success_url=success_url,
            cancel_url=cancel_url,
            metadata={"user_id": user_id, "tier": tier},
        )
        return session.url

    async def handle_webhook(
        self,
        payload: bytes,
        sig_header: str,
        db: AsyncSession,
    ) -> None:
        """Verify a Stripe webhook event and apply it to the subscription table.

        Does nothing when Stripe is not configured. Handled event types:

        * ``checkout.session.completed`` – upsert the user's row as ``active``
          (user and tier come from the session metadata).
        * ``customer.subscription.updated`` – sync status, period end and,
          when the price maps to a known tier, the tier.
        * ``customer.subscription.deleted`` – downgrade to ``free``/``canceled``.
        * ``invoice.payment_failed`` – mark ``past_due``.

        Other event types are ignored. The session is committed at the end.

        Raises:
            HTTPException: 400 if the payload cannot be decoded or the
                signature does not verify against
                ``settings.STRIPE_WEBHOOK_SECRET``.
        """
        if not self._configured():
            return

        s = self._client()
        try:
            event = s.Webhook.construct_event(
                payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
            )
        except ValueError as exc:
            # Undecodable / non-JSON body: answer 400 instead of escaping as 500.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid Stripe payload",
            ) from exc
        except stripe_lib.error.SignatureVerificationError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid Stripe signature",
            ) from exc

        event_type: str = event["type"]
        data: dict[str, Any] = event["data"]["object"]

        if event_type == "checkout.session.completed":
            metadata = data.get("metadata") or {}
            user_id = metadata.get("user_id")
            if user_id:
                sub_id = data.get("subscription")
                await self._upsert_subscription(
                    db,
                    user_id,
                    sub_id,
                    metadata.get("tier", "free"),
                    "active",
                    self._fetch_period_end(s, sub_id),
                )

        elif event_type == "customer.subscription.updated":
            sub_id = data.get("id")
            if sub_id:
                await self._update_subscription_by_stripe_id(
                    db,
                    sub_id,
                    # None (unknown price) leaves the stored tier unchanged.
                    tier=self._tier_from_subscription(data),
                    status=data.get("status", "active"),
                    current_period_end=self._period_end_from_subscription(data),
                )

        elif event_type == "customer.subscription.deleted":
            sub_id = data.get("id")
            if sub_id:
                await self._update_subscription_by_stripe_id(
                    db, sub_id, tier="free", status="canceled"
                )

        elif event_type == "invoice.payment_failed":
            sub_id = data.get("subscription")
            if sub_id:
                await self._update_subscription_by_stripe_id(
                    db, sub_id, status="past_due"
                )

        await db.commit()

    async def get_subscription(
        self, user_id: str, db: AsyncSession
    ) -> dict[str, Any] | None:
        """Return the stored subscription for ``user_id``, or None.

        The dict has keys ``tier``, ``stripe_subscription_id``, ``status`` and
        ``current_period_end`` (Unix epoch milliseconds, or None).
        """
        sub = await self._get_by_user_id(db, user_id)
        if sub is None:
            return None
        return {
            "tier": sub.tier,
            "stripe_subscription_id": sub.stripe_subscription_id,
            "status": sub.status,
            "current_period_end": (
                int(sub.current_period_end.timestamp() * 1000)
                if sub.current_period_end
                else None
            ),
        }

    async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
        """Cancel the user's Stripe subscription and downgrade to free.

        The Stripe call is skipped when Stripe is not configured; the local
        row is updated and committed either way.

        Raises:
            HTTPException: 404 if the user has no row, no Stripe subscription
                ID, or the subscription is already canceled.
        """
        sub = await self._get_by_user_id(db, user_id)
        if (
            sub is None
            or not sub.stripe_subscription_id
            # Already over; don't ask Stripe to cancel it a second time.
            or sub.status == "canceled"
        ):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No active subscription found",
            )

        if self._configured():
            s = self._client()
            s.Subscription.cancel(sub.stripe_subscription_id)

        sub.tier = "free"
        sub.status = "canceled"
        await db.commit()

    # ── Private DB helpers ───────────────────────────────────────────────

    async def _get_by_user_id(
        self, db: AsyncSession, user_id: str
    ) -> Subscription | None:
        """Load the subscription row for ``user_id``, or None."""
        result = await db.execute(
            select(Subscription).where(Subscription.user_id == user_id)
        )
        return result.scalar_one_or_none()

    async def _upsert_subscription(
        self,
        db: AsyncSession,
        user_id: str,
        stripe_subscription_id: str | None,
        tier: str,
        sub_status: str,
        current_period_end: datetime | None,
    ) -> None:
        """Create or overwrite the user's row with the given values (no commit)."""
        sub = await self._get_by_user_id(db, user_id)
        if sub is None:
            sub = Subscription(user_id=user_id)
            db.add(sub)
        sub.stripe_subscription_id = stripe_subscription_id
        sub.tier = tier
        sub.status = sub_status
        sub.current_period_end = current_period_end

    async def _update_subscription_by_stripe_id(
        self,
        db: AsyncSession,
        stripe_subscription_id: str,
        *,
        tier: str | None = None,
        status: str | None = None,  # shadows the fastapi `status` import here
        current_period_end: datetime | None = None,
    ) -> None:
        """Patch the row matching a Stripe subscription ID (no commit).

        Only fields passed as non-None are written. Unknown IDs are ignored.
        """
        result = await db.execute(
            select(Subscription).where(
                Subscription.stripe_subscription_id == stripe_subscription_id
            )
        )
        sub = result.scalar_one_or_none()
        if sub is None:
            return
        if tier is not None:
            sub.tier = tier
        if status is not None:
            sub.status = status
        if current_period_end is not None:
            sub.current_period_end = current_period_end
|
|
|
|
|
|
# Shared instance for the rest of the service to import.
stripe_service: StripeService = StripeService()
|