"""Stripe service: checkout sessions, webhook handling, subscription management. Subscriptions are stored in-memory until Step 12 migrates them to the PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not configured, enabling local development without live credentials. """ from __future__ import annotations from typing import Any import stripe as stripe_lib from fastapi import HTTPException, status from app.config.settings import settings # Stripe price IDs per tier — replace with real IDs in production .env TIER_PRICE_IDS: dict[str, str] = { "pro": "price_pro_monthly", "power": "price_power_monthly", "team": "price_team_monthly", } class StripeService: """Wraps all Stripe interactions and owns the in-memory subscription store. Step 12 will replace ``_subscriptions`` with real PostgreSQL queries. """ def __init__(self) -> None: # user_id → subscription record dict # Replaced by the ``subscriptions`` table in Step 12. self._subscriptions: dict[str, dict[str, Any]] = {} # ── Internal helpers ──────────────────────────────────────────────── def _configured(self) -> bool: return bool(settings.STRIPE_SECRET_KEY) def _client(self) -> Any: stripe_lib.api_key = settings.STRIPE_SECRET_KEY return stripe_lib # ── Public API ────────────────────────────────────────────────────── def create_checkout_session( self, user_id: str, tier: str, success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}", cancel_url: str = "https://app.adiuva.app/billing/cancel", ) -> str: """Create a Stripe checkout session and return the URL. Returns a stub URL when Stripe is not configured. Raises ``HTTP 400`` for the free tier or an unknown tier. """ if tier == "free": raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot create a checkout session for the free tier", ) price_id = TIER_PRICE_IDS.get(tier) if not price_id: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unknown tier: {tier}", ) if not self._configured(): return "https://stripe.com/stub-checkout" s = self._client() session = s.checkout.Session.create( payment_method_types=["card"], mode="subscription", line_items=[{"price": price_id, "quantity": 1}], success_url=success_url, cancel_url=cancel_url, metadata={"user_id": user_id, "tier": tier}, ) return session.url def handle_webhook(self, payload: bytes, sig_header: str) -> None: """Process a Stripe webhook event. Verifies the signature, then dispatches on event type. Raises ``HTTP 400`` on signature mismatch. No-ops when Stripe is not configured. """ if not self._configured(): return try: s = self._client() event = s.Webhook.construct_event( payload, sig_header, settings.STRIPE_WEBHOOK_SECRET ) except stripe_lib.error.SignatureVerificationError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid Stripe signature", ) event_type: str = event["type"] data: dict[str, Any] = event["data"]["object"] if event_type == "checkout.session.completed": user_id = data.get("metadata", {}).get("user_id") tier = data.get("metadata", {}).get("tier", "free") sub_id = data.get("subscription") period_end = data.get("current_period_end") if user_id: self._subscriptions[user_id] = { "tier": tier, "stripe_subscription_id": sub_id, "status": "active", "current_period_end": period_end, } elif event_type == "customer.subscription.updated": # TODO(Step12): look up user_id from stripe_customer_id in DB, update tier sub_id = data.get("id") new_status = data.get("status") period_end = data.get("current_period_end") for record in self._subscriptions.values(): if record.get("stripe_subscription_id") == sub_id: record["status"] = new_status record["current_period_end"] = period_end break elif event_type == "customer.subscription.deleted": # TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free sub_id = data.get("id") for user_id, record in self._subscriptions.items(): if record.get("stripe_subscription_id") == sub_id: self._subscriptions[user_id] = { **record, "tier": "free", "status": "canceled", } break elif event_type == "invoice.payment_failed": # TODO(Step12): flag subscription as past_due, notify user sub_id = data.get("subscription") for record in self._subscriptions.values(): if record.get("stripe_subscription_id") == sub_id: record["status"] = "past_due" break def get_subscription(self, user_id: str) -> dict[str, Any] | None: """Return the subscription record for ``user_id``, or ``None`` if absent.""" return self._subscriptions.get(user_id) def cancel_subscription(self, user_id: str) -> None: """Cancel the user's Stripe subscription and downgrade them to free. Raises ``HTTP 404`` when no active subscription exists. """ sub = self._subscriptions.get(user_id) if sub is None or not sub.get("stripe_subscription_id"): raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No active subscription found", ) if self._configured(): s = self._client() s.Subscription.cancel(sub["stripe_subscription_id"]) self._subscriptions[user_id] = { **sub, "tier": "free", "status": "canceled", } # Module-level singleton shared across the app. stripe_service = StripeService()