diff --git a/services/billing/Dockerfile b/services/billing/Dockerfile new file mode 100644 index 0000000..ab8d1cc --- /dev/null +++ b/services/billing/Dockerfile @@ -0,0 +1,36 @@ +# ── builder ────────────────────────────────────────────────────────────────── +FROM python:3.12-slim AS builder + +WORKDIR /build + +COPY services/billing/requirements.txt ./requirements.txt +RUN pip install --upgrade pip && \ + pip install --no-cache-dir --prefix=/install -r requirements.txt + +# ── runtime ────────────────────────────────────────────────────────────────── +FROM python:3.12-slim AS runtime + +RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser + +WORKDIR /app + +COPY --from=builder /install /usr/local + +# Shared module +COPY shared/ shared/ + +# Service source +COPY services/billing/app/ app/ + +RUN chown -R appuser:appgroup /app + +USER appuser + +EXPOSE 8000 + +# Billing is lightweight — single worker is fine +CMD ["gunicorn", "app.main:app", \ + "-k", "uvicorn.workers.UvicornWorker", \ + "--bind", "0.0.0.0:8000", \ + "--workers", "1", \ + "--timeout", "30"] diff --git a/services/billing/app/main.py b/services/billing/app/main.py new file mode 100644 index 0000000..41debb3 --- /dev/null +++ b/services/billing/app/main.py @@ -0,0 +1,46 @@ +"""Billing Service — FastAPI application. + +Owns: Stripe checkout/webhook, subscription management, tier feature matrix, +quota enforcement. + +Downstream services query this service (or read the user's tier from +the X-User-Tier header injected by Traefik) for billing decisions. +The webhook endpoint is exposed WITHOUT ForwardAuth so Stripe can reach it. +""" + +from __future__ import annotations + +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.routes import router + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + logger.info("billing: service started") + yield + logger.info("billing: service stopped") + + +app = FastAPI(title="Adiuva Billing Service", lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["GET", "POST", "DELETE"], + allow_headers=["*"], +) + +app.include_router(router) + + +@app.get("/health") +async def health() -> dict[str, str]: + return {"status": "ok", "service": "billing"} diff --git a/services/billing/app/routes.py b/services/billing/app/routes.py new file mode 100644 index 0000000..df8b5e3 --- /dev/null +++ b/services/billing/app/routes.py @@ -0,0 +1,134 @@ +"""Billing routes: Stripe checkout, webhook, subscription, tier query. + +Adapted for the Billing microservice: + - Authenticated routes use Traefik-injected headers (X-User-Id, X-User-Tier) + - Webhook route has NO auth (Stripe signature verification only) + - Added /tier/{user_id} for internal service-to-service tier lookups + - Added /features/{tier} for feature matrix queries +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Header, HTTPException, Request, status +from pydantic import BaseModel + +from shared.db import async_session +from shared.schemas import BillingTier + +from app.stripe_service import stripe_service +from app.tier_manager import tier_manager, FEATURES, RATE_LIMITS + +router = APIRouter(prefix="/billing", tags=["billing"]) + + +# ── Request bodies ───────────────────────────────────────────────────── + +class _CheckoutRequest(BaseModel): + tier: BillingTier + + +# ── Checkout ─────────────────────────────────────────────────────────── + +@router.post("/checkout") +async def create_checkout( + body: _CheckoutRequest, + x_user_id: str = Header(..., alias="X-User-Id"), +) -> dict[str, str]: + """Create a Stripe checkout session for a tier upgrade.""" + url = stripe_service.create_checkout_session(x_user_id, body.tier) + return {"checkout_url": url} + + +# ── Webhook (NO auth — Stripe signature only) ───────────────────────── + +@router.post("/webhook") +async def stripe_webhook( + request: Request, + stripe_signature: str = Header(default="", alias="Stripe-Signature"), +) -> dict[str, bool]: + """Handle Stripe webhook events. + + This endpoint is exposed without ForwardAuth in Traefik config + so Stripe can reach it directly. + """ + payload = await request.body() + async with async_session() as db: + await stripe_service.handle_webhook(payload, stripe_signature, db) + return {"ok": True} + + +# ── Subscription CRUD ───────────────────────────────────────────────── + +@router.get("/subscription") +async def get_subscription( + x_user_id: str = Header(..., alias="X-User-Id"), + x_user_tier: str = Header("free", alias="X-User-Tier"), +) -> dict[str, Any]: + """Return the current subscription info for the authenticated user.""" + async with async_session() as db: + sub = await stripe_service.get_subscription(x_user_id, db) + if sub is None: + return { + "tier": x_user_tier, + "status": "free", + "stripe_subscription_id": None, + "current_period_end": None, + } + return sub + + +@router.delete("/subscription") +async def cancel_subscription( + x_user_id: str = Header(..., alias="X-User-Id"), +) -> dict[str, bool]: + """Cancel the active subscription.""" + async with async_session() as db: + await stripe_service.cancel_subscription(x_user_id, db) + return {"ok": True} + + +# ── Tier query (internal, service-to-service) ───────────────────────── + +@router.get("/tier/{user_id}") +async def get_user_tier(user_id: str) -> dict[str, str]: + """Return the billing tier for a given user_id. + + Used by other services for tier lookups. Protected by Traefik + ForwardAuth — only internal services should call this. + """ + async with async_session() as db: + tier = await tier_manager.get_tier(user_id, db) + return {"user_id": user_id, "tier": tier} + + +# ── Feature matrix (public, cacheable) ──────────────────────────────── + +@router.get("/features/{tier}") +async def get_tier_features(tier: str) -> dict[str, Any]: + """Return the feature matrix for a tier.""" + if tier not in FEATURES: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Unknown tier: {tier}", + ) + return { + "tier": tier, + "features": FEATURES[tier], + "rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]), + } + + +@router.get("/features") +async def get_all_features() -> dict[str, Any]: + """Return the full feature matrix for all tiers.""" + return { + "tiers": { + tier: { + "features": features, + "rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]), + } + for tier, features in FEATURES.items() + }, + } diff --git a/services/billing/app/stripe_service.py b/services/billing/app/stripe_service.py new file mode 100644 index 0000000..9906bd7 --- /dev/null +++ b/services/billing/app/stripe_service.py @@ -0,0 +1,240 @@ +"""Stripe service: checkout sessions, webhook handling, subscription management. + +Adapted for the Billing microservice — uses shared.models and shared.db. +All Stripe calls are gracefully stubbed when STRIPE_SECRET_KEY is not +configured, enabling local development without live credentials. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import stripe as stripe_lib +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.config import settings +from shared.models import Subscription + +# Stripe price IDs per tier — replace with real IDs in production .env +TIER_PRICE_IDS: dict[str, str] = { + "pro": "price_pro_monthly", + "power": "price_power_monthly", + "team": "price_team_monthly", +} + + +class StripeService: + """Wraps all Stripe interactions and owns subscription persistence.""" + + # ── Internal helpers ──────────────────────────────────────────────── + + def _configured(self) -> bool: + return bool(settings.STRIPE_SECRET_KEY) + + def _client(self) -> Any: + stripe_lib.api_key = settings.STRIPE_SECRET_KEY + return stripe_lib + + # ── Public API ────────────────────────────────────────────────────── + + def create_checkout_session( + self, + user_id: str, + tier: str, + success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}", + cancel_url: str = "https://app.adiuva.app/billing/cancel", + ) -> str: + """Create a Stripe checkout session and return the URL.""" + if tier == "free": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot create a checkout session for the free tier", + ) + + price_id = TIER_PRICE_IDS.get(tier) + if not price_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unknown tier: {tier}", + ) + + if not self._configured(): + return "https://stripe.com/stub-checkout" + + s = self._client() + session = s.checkout.Session.create( + payment_method_types=["card"], + mode="subscription", + line_items=[{"price": price_id, "quantity": 1}], + success_url=success_url, + cancel_url=cancel_url, + metadata={"user_id": user_id, "tier": tier}, + ) + return session.url + + async def handle_webhook( + self, + payload: bytes, + sig_header: str, + db: AsyncSession, + ) -> None: + """Process a Stripe webhook event. + + Verifies the signature, then dispatches on event type. + """ + if not self._configured(): + return + + try: + s = self._client() + event = s.Webhook.construct_event( + payload, sig_header, settings.STRIPE_WEBHOOK_SECRET + ) + except stripe_lib.error.SignatureVerificationError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid Stripe signature", + ) + + event_type: str = event["type"] + data: dict[str, Any] = event["data"]["object"] + + if event_type == "checkout.session.completed": + user_id = data.get("metadata", {}).get("user_id") + tier = data.get("metadata", {}).get("tier", "free") + sub_id = data.get("subscription") + period_end_ts = data.get("current_period_end") + period_end = ( + datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + if period_end_ts + else None + ) + if user_id: + await self._upsert_subscription( + db, user_id, sub_id, tier, "active", period_end + ) + + elif event_type == "customer.subscription.updated": + sub_id = data.get("id") + new_status = data.get("status", "active") + period_end_ts = data.get("current_period_end") + period_end = ( + datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + if period_end_ts + else None + ) + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, status=new_status, current_period_end=period_end + ) + + elif event_type == "customer.subscription.deleted": + sub_id = data.get("id") + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, tier="free", status="canceled" + ) + + elif event_type == "invoice.payment_failed": + sub_id = data.get("subscription") + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, status="past_due" + ) + + await db.commit() + + async def get_subscription( + self, user_id: str, db: AsyncSession + ) -> dict[str, Any] | None: + """Return the subscription record for user_id, or None.""" + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None: + return None + return { + "tier": sub.tier, + "stripe_subscription_id": sub.stripe_subscription_id, + "status": sub.status, + "current_period_end": ( + int(sub.current_period_end.timestamp() * 1000) + if sub.current_period_end + else None + ), + } + + async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None: + """Cancel the user's Stripe subscription and downgrade to free.""" + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None or not sub.stripe_subscription_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No active subscription found", + ) + + if self._configured(): + s = self._client() + s.Subscription.cancel(sub.stripe_subscription_id) + + sub.tier = "free" + sub.status = "canceled" + await db.commit() + + # ── Private DB helpers ─────────────────────────────────────────────── + + async def _upsert_subscription( + self, + db: AsyncSession, + user_id: str, + stripe_subscription_id: str | None, + tier: str, + sub_status: str, + current_period_end: datetime | None, + ) -> None: + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None: + sub = Subscription(user_id=user_id) + db.add(sub) + sub.stripe_subscription_id = stripe_subscription_id + sub.tier = tier + sub.status = sub_status + sub.current_period_end = current_period_end + + async def _update_subscription_by_stripe_id( + self, + db: AsyncSession, + stripe_subscription_id: str, + *, + tier: str | None = None, + status: str | None = None, + current_period_end: datetime | None = None, + ) -> None: + result = await db.execute( + select(Subscription).where( + Subscription.stripe_subscription_id == stripe_subscription_id + ) + ) + sub = result.scalar_one_or_none() + if sub is None: + return + if tier is not None: + sub.tier = tier + if status is not None: + sub.status = status + if current_period_end is not None: + sub.current_period_end = current_period_end + + +# Module-level singleton +stripe_service = StripeService() diff --git a/services/billing/app/tier_manager.py b/services/billing/app/tier_manager.py new file mode 100644 index 0000000..9202ec7 --- /dev/null +++ b/services/billing/app/tier_manager.py @@ -0,0 +1,178 @@ +"""Tier manager: feature matrix and quota enforcement. + +Single source of truth for what each billing tier allows. +Other services can query the /tier/{user_id} endpoint or rely on the +X-User-Tier header injected by Traefik. +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from shared.config import settings +from shared.models import Subscription +from shared.schemas import BillingTier + +# Feature matrix per tier. -1 means unlimited; 0 means disabled. +FEATURES: dict[str, dict[str, Any]] = { + "free": { + "agents": 3, + "batch_active": 2, + "batch_runs_per_day": 5, + "cloud_storage_gb": 0, + "backup_gb": 0, + "providers": 1, + "batch_builder": False, + "plugin_marketplace": False, + "sso": False, + }, + "pro": { + "agents": -1, + "batch_active": 10, + "batch_runs_per_day": 50, + "cloud_storage_gb": 5, + "backup_gb": 5, + "providers": -1, + "batch_builder": False, + "plugin_marketplace": False, + "sso": False, + }, + "power": { + "agents": -1, + "batch_active": -1, + "batch_runs_per_day": -1, + "cloud_storage_gb": 25, + "backup_gb": 25, + "providers": -1, + "batch_builder": True, + "plugin_marketplace": True, + "sso": False, + }, + "team": { + "agents": -1, + "batch_active": -1, + "batch_runs_per_day": -1, + "cloud_storage_gb": -1, + "backup_gb": -1, + "providers": -1, + "batch_builder": True, + "plugin_marketplace": True, + "sso": True, + }, +} + +# Requests-per-minute limit per tier. +RATE_LIMITS: dict[str, int] = { + "free": 20, + "pro": 60, + "power": 120, + "team": 200, +} + + +class TierManager: + """Centralises tier feature-gating, rate-limit lookups, and quota checks.""" + + async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: + """Return the current billing tier for user_id from the DB.""" + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + tier: str | None = result.scalar_one_or_none() + if tier is None or tier not in FEATURES: + return "power" if settings.ENV == "dev" else "free" + return tier # type: ignore[return-value] + + def get_features(self, tier: BillingTier) -> dict[str, Any]: + """Return the full feature dict for a tier.""" + return FEATURES.get(tier, FEATURES["free"]) + + def check_feature(self, tier: BillingTier, feature: str) -> bool: + """Return True if tier has feature enabled.""" + value = FEATURES.get(tier, FEATURES["free"]).get(feature) + if value is None: + return False + if isinstance(value, bool): + return value + return value != 0 + + def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None: + """Raise HTTP 403 if tier does not have feature.""" + if not self.check_feature(tier, feature): + detail = ( + f"Feature '{feature}' requires {tier_name} tier or above." + if tier_name + else f"Feature '{feature}' is not available on your current tier." + ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) + + def get_rate_limit(self, tier: BillingTier) -> int: + """Return the requests-per-minute limit for tier.""" + return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) + + def enforce_quota( + self, + tier: BillingTier, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> None: + """Raise HTTP 402 if the user would exceed their cloud storage quota.""" + limit_gb: int = FEATURES[tier]["cloud_storage_gb"] + if limit_gb == 0: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Cloud storage is not available on the '{tier}' tier", + ) + if limit_gb == -1: + return + limit_bytes = limit_gb * 1024 ** 3 + if current_bytes + additional_bytes > limit_bytes: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Storage quota exceeded for tier '{tier}'", + ) + + def enforce_backup_quota( + self, + tier: BillingTier, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> None: + """Raise HTTP 402 if the user would exceed their backup quota.""" + limit_gb: int = FEATURES[tier]["backup_gb"] + if limit_gb == 0: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Backup is not available on the '{tier}' tier", + ) + if limit_gb == -1: + return + limit_bytes = limit_gb * 1024 ** 3 + if current_bytes + additional_bytes > limit_bytes: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Backup quota exceeded for tier '{tier}'", + ) + + def check_quota( + self, + tier: BillingTier, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> bool: + """Return True if the user can store additional_bytes more data.""" + limit_gb: int = FEATURES[tier]["cloud_storage_gb"] + if limit_gb == 0: + return False + if limit_gb == -1: + return True + limit_bytes = limit_gb * 1024 ** 3 + return current_bytes + additional_bytes <= limit_bytes + + +# Module-level singleton +tier_manager = TierManager() diff --git a/services/billing/requirements.txt b/services/billing/requirements.txt new file mode 100644 index 0000000..9b220dc --- /dev/null +++ b/services/billing/requirements.txt @@ -0,0 +1,9 @@ +fastapi>=0.115.0 +uvicorn[standard]>=0.34.0 +gunicorn>=22.0.0 +pydantic>=2.10.0 +pydantic-settings>=2.7.0 +sqlalchemy>=2.0.0 +asyncpg>=0.30.0 +python-dotenv>=1.0.0 +stripe>=8.0.0