184 lines
6.7 KiB
Python
184 lines
6.7 KiB
Python
"""Stripe service: checkout sessions, webhook handling, subscription management.

Subscriptions are stored in-memory until Step 12 migrates them to the
PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed
when ``STRIPE_SECRET_KEY`` is not configured, enabling local development
without live credentials.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import stripe as stripe_lib
|
|
from fastapi import HTTPException, status
|
|
|
|
from app.config.settings import settings
|
|
|
|
# Stripe price IDs per tier — replace with real IDs in production .env
# Keys are the paid tier names accepted by ``create_checkout_session``; values
# are the Stripe price IDs used as the checkout line item. There is no entry
# for "free": that tier is rejected before this table is consulted.
TIER_PRICE_IDS: dict[str, str] = {
    "pro": "price_pro_monthly",
    "power": "price_power_monthly",
    "team": "price_team_monthly",
}
|
|
|
|
|
|
class StripeService:
|
|
"""Wraps all Stripe interactions and owns the in-memory subscription store.
|
|
|
|
Step 12 will replace ``_subscriptions`` with real PostgreSQL queries.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
# user_id → subscription record dict
|
|
# Replaced by the ``subscriptions`` table in Step 12.
|
|
self._subscriptions: dict[str, dict[str, Any]] = {}
|
|
|
|
# ── Internal helpers ────────────────────────────────────────────────
|
|
|
|
def _configured(self) -> bool:
|
|
return bool(settings.STRIPE_SECRET_KEY)
|
|
|
|
def _client(self) -> Any:
|
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
return stripe_lib
|
|
|
|
# ── Public API ──────────────────────────────────────────────────────
|
|
|
|
def create_checkout_session(
|
|
self,
|
|
user_id: str,
|
|
tier: str,
|
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
|
) -> str:
|
|
"""Create a Stripe checkout session and return the URL.
|
|
|
|
Returns a stub URL when Stripe is not configured.
|
|
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
|
"""
|
|
if tier == "free":
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail="Cannot create a checkout session for the free tier",
|
|
)
|
|
|
|
price_id = TIER_PRICE_IDS.get(tier)
|
|
if not price_id:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail=f"Unknown tier: {tier}",
|
|
)
|
|
|
|
if not self._configured():
|
|
return "https://stripe.com/stub-checkout"
|
|
|
|
s = self._client()
|
|
session = s.checkout.Session.create(
|
|
payment_method_types=["card"],
|
|
mode="subscription",
|
|
line_items=[{"price": price_id, "quantity": 1}],
|
|
success_url=success_url,
|
|
cancel_url=cancel_url,
|
|
metadata={"user_id": user_id, "tier": tier},
|
|
)
|
|
return session.url
|
|
|
|
def handle_webhook(self, payload: bytes, sig_header: str) -> None:
|
|
"""Process a Stripe webhook event.
|
|
|
|
Verifies the signature, then dispatches on event type.
|
|
Raises ``HTTP 400`` on signature mismatch.
|
|
No-ops when Stripe is not configured.
|
|
"""
|
|
if not self._configured():
|
|
return
|
|
|
|
try:
|
|
s = self._client()
|
|
event = s.Webhook.construct_event(
|
|
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
|
)
|
|
except stripe_lib.error.SignatureVerificationError:
|
|
raise HTTPException(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
detail="Invalid Stripe signature",
|
|
)
|
|
|
|
event_type: str = event["type"]
|
|
data: dict[str, Any] = event["data"]["object"]
|
|
|
|
if event_type == "checkout.session.completed":
|
|
user_id = data.get("metadata", {}).get("user_id")
|
|
tier = data.get("metadata", {}).get("tier", "free")
|
|
sub_id = data.get("subscription")
|
|
period_end = data.get("current_period_end")
|
|
if user_id:
|
|
self._subscriptions[user_id] = {
|
|
"tier": tier,
|
|
"stripe_subscription_id": sub_id,
|
|
"status": "active",
|
|
"current_period_end": period_end,
|
|
}
|
|
|
|
elif event_type == "customer.subscription.updated":
|
|
# TODO(Step12): look up user_id from stripe_customer_id in DB, update tier
|
|
sub_id = data.get("id")
|
|
new_status = data.get("status")
|
|
period_end = data.get("current_period_end")
|
|
for record in self._subscriptions.values():
|
|
if record.get("stripe_subscription_id") == sub_id:
|
|
record["status"] = new_status
|
|
record["current_period_end"] = period_end
|
|
break
|
|
|
|
elif event_type == "customer.subscription.deleted":
|
|
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
|
sub_id = data.get("id")
|
|
for user_id, record in self._subscriptions.items():
|
|
if record.get("stripe_subscription_id") == sub_id:
|
|
self._subscriptions[user_id] = {
|
|
**record,
|
|
"tier": "free",
|
|
"status": "canceled",
|
|
}
|
|
break
|
|
|
|
elif event_type == "invoice.payment_failed":
|
|
# TODO(Step12): flag subscription as past_due, notify user
|
|
sub_id = data.get("subscription")
|
|
for record in self._subscriptions.values():
|
|
if record.get("stripe_subscription_id") == sub_id:
|
|
record["status"] = "past_due"
|
|
break
|
|
|
|
def get_subscription(self, user_id: str) -> dict[str, Any] | None:
|
|
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
|
return self._subscriptions.get(user_id)
|
|
|
|
def cancel_subscription(self, user_id: str) -> None:
|
|
"""Cancel the user's Stripe subscription and downgrade them to free.
|
|
|
|
Raises ``HTTP 404`` when no active subscription exists.
|
|
"""
|
|
sub = self._subscriptions.get(user_id)
|
|
if sub is None or not sub.get("stripe_subscription_id"):
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="No active subscription found",
|
|
)
|
|
|
|
if self._configured():
|
|
s = self._client()
|
|
s.Subscription.cancel(sub["stripe_subscription_id"])
|
|
|
|
self._subscriptions[user_id] = {
|
|
**sub,
|
|
"tier": "free",
|
|
"status": "canceled",
|
|
}
|
|
|
|
|
|
# Module-level singleton shared across the app. Because subscriptions are
# held in this instance's in-memory dict, every caller must go through this
# object to see the same records.
stripe_service: StripeService = StripeService()
|