"""Stripe service: checkout sessions, webhook handling, subscription management.

Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
configured, enabling local development without live credentials.
"""

from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from typing import Any

import stripe as stripe_lib
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config.settings import settings

# Stripe price IDs per tier — replace with real IDs in production .env
TIER_PRICE_IDS: dict[str, str] = {
    "pro": "price_pro_monthly",
    "power": "price_power_monthly",
    "team": "price_team_monthly",
}


def _timestamp_to_datetime(ts: int | None) -> datetime | None:
    """Convert a Unix-seconds timestamp to an aware UTC ``datetime``.

    Returns ``None`` when *ts* is missing or falsy.
    """
    return datetime.fromtimestamp(ts, tz=timezone.utc) if ts else None


class StripeService:
    """Wraps all Stripe interactions and owns subscription persistence."""

    # ── Internal helpers ────────────────────────────────────────────────

    def _configured(self) -> bool:
        """Return ``True`` when a Stripe secret key is present in settings."""
        return bool(settings.STRIPE_SECRET_KEY)

    def _client(self) -> Any:
        """Return the ``stripe`` module with its API key set from settings."""
        # stripe-python reads the module-level ``api_key``, so it is assigned
        # before every use.
        stripe_lib.api_key = settings.STRIPE_SECRET_KEY
        return stripe_lib

    # ── Public API ──────────────────────────────────────────────────────

    def create_checkout_session(
        self,
        user_id: str,
        tier: str,
        success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
        cancel_url: str = "https://app.adiuvai.app/billing/cancel",
    ) -> str:
        """Create a Stripe checkout session and return the URL.

        Args:
            user_id: Application user ID, stored in the session metadata so the
                ``checkout.session.completed`` webhook can locate the user.
            tier: Paid tier key; must be one of ``TIER_PRICE_IDS``.
            success_url: Redirect target after payment. ``{CHECKOUT_SESSION_ID}``
                is left as-is for Stripe to fill in (it is never ``.format``-ed).
            cancel_url: Redirect target when the user abandons checkout.

        Returns:
            The hosted checkout URL, or a stub URL when Stripe is not configured.

        Raises:
            HTTPException: 400 for the free tier or an unknown tier.
        """
        if tier == "free":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Cannot create a checkout session for the free tier",
            )

        price_id = TIER_PRICE_IDS.get(tier)
        if not price_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unknown tier: {tier}",
            )

        if not self._configured():
            return "https://stripe.com/stub-checkout"

        s = self._client()
        session = s.checkout.Session.create(
            payment_method_types=["card"],
            mode="subscription",
            line_items=[{"price": price_id, "quantity": 1}],
            success_url=success_url,
            cancel_url=cancel_url,
            metadata={"user_id": user_id, "tier": tier},
        )
        return session.url

    async def handle_webhook(
        self,
        payload: bytes,
        sig_header: str,
        db: AsyncSession,
    ) -> None:
        """Process a Stripe webhook event and commit the resulting DB changes.

        Verifies the signature, then dispatches on event type. Unrecognised
        event types are ignored (the session is still committed).
        No-ops when Stripe is not configured.

        Raises:
            HTTPException: 400 when the signature does not verify or the
                payload cannot be decoded/parsed.
        """
        if not self._configured():
            return

        s = self._client()
        try:
            event = s.Webhook.construct_event(
                payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
            )
        except stripe_lib.error.SignatureVerificationError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid Stripe signature",
            ) from exc
        except ValueError as exc:
            # construct_event raises ValueError for an undecodable / non-JSON
            # body; that is a client error, not a server fault.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid Stripe payload",
            ) from exc

        event_type: str = event["type"]
        data: dict[str, Any] = event["data"]["object"]

        if event_type == "checkout.session.completed":
            # ``or {}`` also guards against an explicit ``None`` metadata value.
            metadata = data.get("metadata") or {}
            user_id = metadata.get("user_id")
            tier = metadata.get("tier", "free")
            # NOTE(review): a Checkout Session object probably does not carry
            # ``current_period_end`` (that field lives on the Subscription), so
            # this is likely ``None`` here — confirm against the API version.
            period_end = _timestamp_to_datetime(data.get("current_period_end"))
            if user_id:
                await self._upsert_subscription(
                    db, user_id, data.get("subscription"), tier, "active", period_end
                )

        elif event_type == "customer.subscription.updated":
            sub_id = data.get("id")
            if sub_id:
                await self._update_subscription_by_stripe_id(
                    db,
                    sub_id,
                    status=data.get("status", "active"),
                    current_period_end=_timestamp_to_datetime(
                        data.get("current_period_end")
                    ),
                )

        elif event_type == "customer.subscription.deleted":
            sub_id = data.get("id")
            if sub_id:
                await self._update_subscription_by_stripe_id(
                    db, sub_id, tier="free", status="canceled"
                )

        elif event_type == "invoice.payment_failed":
            sub_id = data.get("subscription")
            if sub_id:
                await self._update_subscription_by_stripe_id(
                    db, sub_id, status="past_due"
                )

        await db.commit()

    async def get_subscription(
        self, user_id: str, db: AsyncSession
    ) -> dict[str, Any] | None:
        """Return the subscription record for ``user_id``, or ``None`` if absent.

        ``current_period_end`` is returned as epoch milliseconds (or ``None``).
        """
        from app.models import Subscription  # noqa: PLC0415

        result = await db.execute(
            select(Subscription).where(Subscription.user_id == user_id)
        )
        sub = result.scalar_one_or_none()
        if sub is None:
            return None
        return {
            "tier": sub.tier,
            "stripe_subscription_id": sub.stripe_subscription_id,
            "status": sub.status,
            "current_period_end": (
                int(sub.current_period_end.timestamp() * 1000)
                if sub.current_period_end
                else None
            ),
        }

    async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
        """Cancel the user's Stripe subscription and downgrade them to free.

        When Stripe is not configured, only the local row is updated.

        Raises:
            HTTPException: 404 when no active subscription exists (no row, no
                Stripe subscription ID, or the row is already canceled).
        """
        from app.models import Subscription  # noqa: PLC0415

        result = await db.execute(
            select(Subscription).where(Subscription.user_id == user_id)
        )
        sub = result.scalar_one_or_none()
        # A canceled row keeps its Stripe ID, so without the status check a
        # second call would re-send a cancel for an inactive subscription.
        if (
            sub is None
            or not sub.stripe_subscription_id
            or sub.status == "canceled"
        ):
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No active subscription found",
            )

        if self._configured():
            s = self._client()
            # The stripe SDK performs blocking network I/O; run it in a worker
            # thread so the event loop is not stalled.
            await asyncio.to_thread(s.Subscription.cancel, sub.stripe_subscription_id)

        sub.tier = "free"
        sub.status = "canceled"
        await db.commit()

    # ── Private DB helpers ───────────────────────────────────────────────

    async def _upsert_subscription(
        self,
        db: AsyncSession,
        user_id: str,
        stripe_subscription_id: str | None,
        tier: str,
        sub_status: str,
        current_period_end: datetime | None,
    ) -> None:
        """Create or overwrite the user's subscription row (caller commits)."""
        from app.models import Subscription  # noqa: PLC0415

        result = await db.execute(
            select(Subscription).where(Subscription.user_id == user_id)
        )
        sub = result.scalar_one_or_none()
        if sub is None:
            sub = Subscription(user_id=user_id)
            db.add(sub)
        sub.stripe_subscription_id = stripe_subscription_id
        sub.tier = tier
        sub.status = sub_status
        sub.current_period_end = current_period_end

    async def _update_subscription_by_stripe_id(
        self,
        db: AsyncSession,
        stripe_subscription_id: str,
        *,
        tier: str | None = None,
        status: str | None = None,  # shadows fastapi ``status``; unused here
        current_period_end: datetime | None = None,
    ) -> None:
        """Patch the row matching a Stripe subscription ID (caller commits).

        Arguments left as ``None`` are not written. Does nothing when no row
        matches ``stripe_subscription_id``.
        """
        from app.models import Subscription  # noqa: PLC0415

        result = await db.execute(
            select(Subscription).where(
                Subscription.stripe_subscription_id == stripe_subscription_id
            )
        )
        sub = result.scalar_one_or_none()
        if sub is None:
            return
        if tier is not None:
            sub.tier = tier
        if status is not None:
            sub.status = status
        if current_period_end is not None:
            sub.current_period_end = current_period_end


# Module-level singleton shared across the app.
stripe_service = StripeService()