From 5d485b3665e6c74649eb11a8c5fc02bc6781f9a3 Mon Sep 17 00:00:00 2001 From: roberto Date: Tue, 3 Mar 2026 12:39:32 +0100 Subject: [PATCH] step 12 --- alembic.ini | 47 +++++ alembic/env.py | 93 +++++++++ alembic/script.py.mako | 28 +++ alembic/versions/001_initial_schema.py | 202 +++++++++++++++++++ app/api/middleware/auth.py | 24 ++- app/api/routes/auth.py | 159 +++++++++++---- app/api/routes/billing.py | 11 +- app/billing/stripe_service.py | 181 ++++++++++++----- app/billing/tier_manager.py | 106 ++++------ app/db.py | 40 ++++ app/main.py | 4 +- app/models.py | 269 +++++++++++++++++++++++++ 12 files changed, 999 insertions(+), 165 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/001_initial_schema.py create mode 100644 app/db.py create mode 100644 app/models.py diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..1223deb --- /dev/null +++ b/alembic.ini @@ -0,0 +1,47 @@ +# Alembic configuration file. +# The async app uses postgresql+asyncpg:// at runtime. +# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var). + +[alembic] +script_location = alembic +prepend_sys_path = . +version_path_separator = os + +# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder. +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..23dac6c --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,93 @@ +"""Alembic migration environment — async-compatible. + +At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is +synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL +env var by replacing the driver prefix. + +Run migrations with: + alembic upgrade head +""" + +from __future__ import annotations + +import asyncio +import os +import re +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool +from sqlalchemy.ext.asyncio import create_async_engine + +# Alembic Config object (gives access to alembic.ini values). +config = context.config + +# Set up Python logging from alembic.ini. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Import the Base so that Alembic can detect model changes for --autogenerate. +from app.models import Base # noqa: E402 + +target_metadata = Base.metadata + + +def _sync_url(async_url: str) -> str: + """Convert an asyncpg URL to a psycopg2 URL for Alembic CLI.""" + return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url) + + +def _get_url() -> str: + db_url = os.environ.get("DATABASE_URL", "") + if not db_url: + # Fall back to settings if env var not set directly. + from app.config.settings import settings # noqa: PLC0415 + db_url = settings.DATABASE_URL + return _sync_url(db_url) + + +def run_migrations_offline() -> None: + """Emit SQL without a live DB connection.""" + url = _get_url() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): # type: ignore[no-untyped-def] + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online_async() -> None: + """Run migrations against a live DB using the async engine.""" + async_url = os.environ.get("DATABASE_URL", "") + if not async_url: + from app.config.settings import settings # noqa: PLC0415 + async_url = settings.DATABASE_URL + + connectable = create_async_engine(async_url, poolclass=pool.NullPool) + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + await connectable.dispose() + + +def run_migrations_online() -> None: + asyncio.run(run_migrations_online_async()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..ee746cf --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/001_initial_schema.py b/alembic/versions/001_initial_schema.py new file mode 100644 index 0000000..abe611a --- /dev/null +++ b/alembic/versions/001_initial_schema.py @@ -0,0 +1,202 @@ +"""Initial schema: users, refresh_tokens, subscriptions, storage_records, +backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events. + +Revision ID: 001 +Revises: +Create Date: 2026-03-02 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "001" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Enum types ──────────────────────────────────────────────────────── + billing_tier = postgresql.ENUM( + "free", "pro", "power", "team", name="billing_tier", create_type=False + ) + plugin_status = postgresql.ENUM( + "pending_review", "approved", "rejected", name="plugin_status", create_type=False + ) + review_decision = postgresql.ENUM( + "approved", "rejected", name="review_decision", create_type=False + ) + for enum in (billing_tier, plugin_status, review_decision): + enum.create(op.get_bind(), checkfirst=True) + + # ── users ───────────────────────────────────────────────────────────── + op.create_table( + "users", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("email", sa.String(255), nullable=False), + sa.Column("password_hash", sa.String(255), nullable=False), + sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"), + sa.Column("stripe_customer_id", sa.String(255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + ) + op.create_index("ix_users_email", "users", ["email"]) + + # ── refresh_tokens ──────────────────────────────────────────────────── + op.create_table( + "refresh_tokens", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("token_hash", sa.String(64), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("token_hash"), + ) + op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"]) + op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"]) + + # ── subscriptions ───────────────────────────────────────────────────── + op.create_table( + "subscriptions", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("stripe_subscription_id", sa.String(255), nullable=True), + sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"), + sa.Column("status", sa.String(50), nullable=False, server_default="free"), + sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("user_id"), + ) + op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"]) + op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"]) + + # ── storage_records ─────────────────────────────────────────────────── + op.create_table( + "storage_records", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("table_name", sa.String(100), nullable=False), + sa.Column("s3_key", sa.String(500), nullable=False), + sa.Column("checksum", sa.String(64), nullable=False), + sa.Column("size_bytes", sa.Integer, nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"]) + + # ── backup_metadata ─────────────────────────────────────────────────── + op.create_table( + "backup_metadata", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("s3_key", sa.String(500), nullable=False), + sa.Column("version", sa.Integer, nullable=False), + sa.Column("timestamp", sa.BigInteger, nullable=False), + sa.Column("checksum", sa.String(64), nullable=False), + sa.Column("size_bytes", sa.Integer, nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"]) + + # ── plugins ─────────────────────────────────────────────────────────── + op.create_table( + "plugins", + sa.Column("id", sa.String(255), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("description", sa.Text, nullable=False, server_default=""), + sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"), + sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True), + sa.Column("author_name", sa.String(255), nullable=False, server_default=""), + sa.Column("category", sa.String(100), nullable=False, server_default=""), + sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"), + sa.Column("permissions", sa.Text, nullable=False, server_default="[]"), + sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"), + sa.Column("s3_package_key", sa.String(500), nullable=True), + sa.Column("install_count", sa.Integer, nullable=False, server_default="0"), + sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"), + sa.Column("rejection_reason", sa.Text, nullable=True), + sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"), + ) + + # ── plugin_installations ────────────────────────────────────────────── + op.create_table( + "plugin_installations", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("plugin_id", sa.String(255), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"), + ) + op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"]) + op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"]) + + # ── plugin_reviews ──────────────────────────────────────────────────── + op.create_table( + "plugin_reviews", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("plugin_id", sa.String(255), nullable=False), + sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True), + sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False), + sa.Column("notes", sa.Text, nullable=True), + sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"), + ) + op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"]) + + # ── revenue_events ──────────────────────────────────────────────────── + op.create_table( + "revenue_events", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("plugin_id", sa.String(255), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"), + sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"), + sa.Column("stripe_transfer_id", sa.String(255), nullable=True), + sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"]) + op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"]) + + +def downgrade() -> None: + op.drop_table("revenue_events") + op.drop_table("plugin_reviews") + op.drop_table("plugin_installations") + op.drop_table("plugins") + op.drop_table("backup_metadata") + op.drop_table("storage_records") + op.drop_table("subscriptions") + op.drop_table("refresh_tokens") + op.drop_table("users") + + op.execute("DROP TYPE IF EXISTS review_decision") + op.execute("DROP TYPE IF EXISTS plugin_status") + op.execute("DROP TYPE IF EXISTS billing_tier") diff --git a/app/api/middleware/auth.py b/app/api/middleware/auth.py index b596121..1cd8df0 100644 --- a/app/api/middleware/auth.py +++ b/app/api/middleware/auth.py @@ -1,8 +1,9 @@ """Auth middleware — JWT validation dependency. ``get_current_user`` is the FastAPI dependency used by all protected routes. -It decodes the Bearer JWT, validates signature and expiry, and returns a -``UserProfile`` carrying ``id``, ``email``, and ``tier``. +It decodes the Bearer JWT (identity + expiry), then fetches the current tier +from the ``subscriptions`` table so that tier changes take effect immediately +without requiring token re-issue. Exempt routes (no JWT required): - POST /api/v1/auth/register @@ -15,8 +16,11 @@ from __future__ import annotations from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.config.settings import settings +from app.db import get_session from app.schemas import UserProfile oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") @@ -24,12 +28,15 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") async def get_current_user( token: str = Depends(oauth2_scheme), + db: AsyncSession = Depends(get_session), ) -> UserProfile: """Validate a Bearer JWT and return the authenticated user. + The JWT is used for identity and expiry only. The tier is fetched live + from the ``subscriptions`` table so that upgrades/downgrades take effect + immediately. Falls back to ``'free'`` when no subscription row exists. + Raises HTTP 401 on any invalid or expired token. - The tier embedded in the JWT is used for feature-gating until Step 12 - adds a live DB lookup. """ credentials_exc = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -42,10 +49,17 @@ async def get_current_user( ) user_id: str | None = payload.get("sub") email: str | None = payload.get("email") - tier: str = payload.get("tier", "free") if not user_id or not email: raise credentials_exc except JWTError: raise credentials_exc + # Live tier lookup — subscription row is the authoritative source. + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + tier: str = result.scalar_one_or_none() or "free" + return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type] diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index 64c0bf5..0fb3046 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -1,33 +1,36 @@ """Auth routes: register, login, refresh, me. -Users and refresh tokens are kept in an in-memory dict until Step 12 -migrates them to PostgreSQL. +Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens +tables). Passwords are hashed with bcrypt; refresh tokens are stored as +SHA-256 hashes so plaintext never reaches the DB. """ from __future__ import annotations +import hashlib import time import uuid -from typing import Any +from datetime import datetime, timedelta, timezone import bcrypt from fastapi import APIRouter, Depends, HTTPException, status from jose import jwt from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.config.settings import settings +from app.db import get_session +from app.models import RefreshToken, User from app.schemas import AuthTokens, UserProfile router = APIRouter(prefix="/auth", tags=["auth"]) -# ── In-memory stores (replaced by PostgreSQL in Step 12) ───────────── -_users: dict[str, dict[str, Any]] = {} # email → user record -_refresh_tokens: dict[str, str] = {} # plain token → user_id - # ── Internal helpers ───────────────────────────────────────────────── + def _hash_password(password: str) -> str: return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() @@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool: return bcrypt.checkpw(password.encode(), hashed.encode()) -def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens: +def _hash_token(plain_token: str) -> str: + """SHA-256 of the plain refresh token string.""" + return hashlib.sha256(plain_token.encode()).hexdigest() + + +def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]: + """Return (signed JWT, expires_at_ms).""" now = int(time.time()) - access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 - access_payload = { + exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 + payload = { "sub": user_id, "email": email, "tier": tier, - "exp": access_exp, + "exp": exp, "iat": now, } - access_token = jwt.encode( - access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM - ) - refresh_token = str(uuid.uuid4()) - _refresh_tokens[refresh_token] = user_id - return AuthTokens( - access_token=access_token, - refresh_token=refresh_token, - expires_at=access_exp * 1000, # milliseconds for client - ) + token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) + return token, exp * 1000 # ms for client # ── Request bodies ──────────────────────────────────────────────────── + class _RegisterRequest(BaseModel): email: str password: str @@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel): # ── Routes ──────────────────────────────────────────────────────────── + @router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED) -async def register(body: _RegisterRequest) -> AuthTokens: +async def register( + body: _RegisterRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: """Create a new account and return JWT tokens.""" - if body.email in _users: + existing = await db.execute(select(User).where(User.email == body.email)) + if existing.scalar_one_or_none() is not None: raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") - user_id = str(uuid.uuid4()) - _users[body.email] = { - "id": user_id, - "email": body.email, - "password_hash": _hash_password(body.password), - "tier": "free", - } - return _make_tokens(user_id, body.email, "free") + + user = User( + id=str(uuid.uuid4()), + email=body.email, + password_hash=_hash_password(body.password), + tier="free", + ) + db.add(user) + await db.flush() # get user.id without committing + + plain_token = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS + ) + rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=expires_at, + ) + db.add(rt) + await db.commit() + + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) @router.post("/login", response_model=AuthTokens) -async def login(body: _LoginRequest) -> AuthTokens: +async def login( + body: _LoginRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: """Validate credentials and return JWT tokens.""" - user = _users.get(body.email) - if not user or not _verify_password(body.password, user["password_hash"]): + result = await db.execute(select(User).where(User.email == body.email)) + user = result.scalar_one_or_none() + if user is None or not _verify_password(body.password, user.password_hash): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials") - return _make_tokens(user["id"], user["email"], user["tier"]) + + plain_token = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS + ) + rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=expires_at, + ) + db.add(rt) + await db.commit() + + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) @router.post("/refresh", response_model=AuthTokens) -async def refresh(body: _RefreshRequest) -> AuthTokens: +async def refresh( + body: _RefreshRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: """Rotate a refresh token and return a new token pair.""" - user_id = _refresh_tokens.pop(body.refresh_token, None) - if user_id is None: + token_hash = _hash_token(body.refresh_token) + result = await db.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + rt = result.scalar_one_or_none() + + now = datetime.now(timezone.utc) + if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token") - user = next((u for u in _users.values() if u["id"] == user_id), None) + + # Rotate: delete old token, issue new one. + await db.delete(rt) + + user_result = await db.execute(select(User).where(User.id == rt.user_id)) + user = user_result.scalar_one_or_none() if user is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found") - return _make_tokens(user["id"], user["email"], user["tier"]) + + plain_token = str(uuid.uuid4()) + new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) + new_rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=new_expires, + ) + db.add(new_rt) + await db.commit() + + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) @router.get("/me", response_model=UserProfile) diff --git a/app/api/routes/billing.py b/app/api/routes/billing.py index 6ca1aa7..e8bdef2 100644 --- a/app/api/routes/billing.py +++ b/app/api/routes/billing.py @@ -11,9 +11,11 @@ from typing import Any from fastapi import APIRouter, Depends, Header, Request, status from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.billing.stripe_service import stripe_service +from app.db import get_session from app.schemas import BillingTier, UserProfile router = APIRouter(prefix="/billing", tags=["billing"]) @@ -44,6 +46,7 @@ async def create_checkout( async def stripe_webhook( request: Request, stripe_signature: str = Header(default="", alias="Stripe-Signature"), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Handle Stripe webhook events. @@ -51,16 +54,17 @@ async def stripe_webhook( Returns 200 immediately when Stripe is not configured (local dev). """ payload = await request.body() - stripe_service.handle_webhook(payload, stripe_signature) + await stripe_service.handle_webhook(payload, stripe_signature, db) return {"ok": True} @router.get("/subscription", response_model=dict) async def get_subscription( current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, Any]: """Return the current subscription info for the authenticated user.""" - sub = stripe_service.get_subscription(current_user.id) + sub = await stripe_service.get_subscription(current_user.id, db) if sub is None: return { "tier": current_user.tier, @@ -74,7 +78,8 @@ async def get_subscription( @router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK) async def cancel_subscription( current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Cancel the active subscription.""" - stripe_service.cancel_subscription(current_user.id) + await stripe_service.cancel_subscription(current_user.id, db) return {"ok": True} diff --git a/app/billing/stripe_service.py b/app/billing/stripe_service.py index 0c68ded..3bd9038 100644 --- a/app/billing/stripe_service.py +++ b/app/billing/stripe_service.py @@ -1,17 +1,19 @@ """Stripe service: checkout sessions, webhook handling, subscription management. -Subscriptions are stored in-memory until Step 12 migrates them to the -PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed -when ``STRIPE_SECRET_KEY`` is not configured, enabling local development -without live credentials. +Subscription records are persisted in the PostgreSQL ``subscriptions`` table. +All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not +configured, enabling local development without live credentials. """ from __future__ import annotations +from datetime import datetime, timezone from typing import Any import stripe as stripe_lib from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.config.settings import settings @@ -24,15 +26,7 @@ TIER_PRICE_IDS: dict[str, str] = { class StripeService: - """Wraps all Stripe interactions and owns the in-memory subscription store. - - Step 12 will replace ``_subscriptions`` with real PostgreSQL queries. - """ - - def __init__(self) -> None: - # user_id → subscription record dict - # Replaced by the ``subscriptions`` table in Step 12. - self._subscriptions: dict[str, dict[str, Any]] = {} + """Wraps all Stripe interactions and owns subscription persistence.""" # ── Internal helpers ──────────────────────────────────────────────── @@ -84,7 +78,12 @@ class StripeService: ) return session.url - def handle_webhook(self, payload: bytes, sig_header: str) -> None: + async def handle_webhook( + self, + payload: bytes, + sig_header: str, + db: AsyncSession, + ) -> None: """Process a Stripe webhook event. Verifies the signature, then dispatches on event type. @@ -112,57 +111,82 @@ class StripeService: user_id = data.get("metadata", {}).get("user_id") tier = data.get("metadata", {}).get("tier", "free") sub_id = data.get("subscription") - period_end = data.get("current_period_end") + period_end_ts = data.get("current_period_end") + period_end = ( + datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + if period_end_ts + else None + ) if user_id: - self._subscriptions[user_id] = { - "tier": tier, - "stripe_subscription_id": sub_id, - "status": "active", - "current_period_end": period_end, - } + await self._upsert_subscription( + db, user_id, sub_id, tier, "active", period_end + ) elif event_type == "customer.subscription.updated": - # TODO(Step12): look up user_id from stripe_customer_id in DB, update tier sub_id = data.get("id") - new_status = data.get("status") - period_end = data.get("current_period_end") - for record in self._subscriptions.values(): - if record.get("stripe_subscription_id") == sub_id: - record["status"] = new_status - record["current_period_end"] = period_end - break + new_status = data.get("status", "active") + period_end_ts = data.get("current_period_end") + period_end = ( + datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + if period_end_ts + else None + ) + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, status=new_status, current_period_end=period_end + ) elif event_type == "customer.subscription.deleted": - # TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free sub_id = data.get("id") - for user_id, record in self._subscriptions.items(): - if record.get("stripe_subscription_id") == sub_id: - self._subscriptions[user_id] = { - **record, - "tier": "free", - "status": "canceled", - } - break + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, tier="free", status="canceled" + ) elif event_type == "invoice.payment_failed": - # TODO(Step12): flag subscription as past_due, notify user sub_id = data.get("subscription") - for record in self._subscriptions.values(): - if record.get("stripe_subscription_id") == sub_id: - record["status"] = "past_due" - break + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, status="past_due" + ) - def get_subscription(self, user_id: str) -> dict[str, Any] | None: + await db.commit() + + async def get_subscription( + self, user_id: str, db: AsyncSession + ) -> dict[str, Any] | None: """Return the subscription record for ``user_id``, or ``None`` if absent.""" - return self._subscriptions.get(user_id) + from app.models import Subscription # noqa: PLC0415 - def cancel_subscription(self, user_id: str) -> None: + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None: + return None + return { + "tier": sub.tier, + "stripe_subscription_id": sub.stripe_subscription_id, + "status": sub.status, + "current_period_end": ( + int(sub.current_period_end.timestamp() * 1000) + if sub.current_period_end + else None + ), + } + + async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None: """Cancel the user's Stripe subscription and downgrade them to free. Raises ``HTTP 404`` when no active subscription exists. """ - sub = self._subscriptions.get(user_id) - if sub is None or not sub.get("stripe_subscription_id"): + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None or not sub.stripe_subscription_id: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No active subscription found", @@ -170,13 +194,62 @@ class StripeService: if self._configured(): s = self._client() - s.Subscription.cancel(sub["stripe_subscription_id"]) + s.Subscription.cancel(sub.stripe_subscription_id) - self._subscriptions[user_id] = { - **sub, - "tier": "free", - "status": "canceled", - } + sub.tier = "free" + sub.status = "canceled" + await db.commit() + + # ── Private DB helpers ─────────────────────────────────────────────── + + async def _upsert_subscription( + self, + db: AsyncSession, + user_id: str, + stripe_subscription_id: str | None, + tier: str, + sub_status: str, + current_period_end: datetime | None, + ) -> None: + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None: + sub = Subscription(user_id=user_id) + db.add(sub) + sub.stripe_subscription_id = stripe_subscription_id + sub.tier = tier + sub.status = sub_status + sub.current_period_end = current_period_end + + async def _update_subscription_by_stripe_id( + self, + db: AsyncSession, + stripe_subscription_id: str, + *, + tier: str | None = None, + status: str | None = None, + current_period_end: datetime | None = None, + ) -> None: + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where( + Subscription.stripe_subscription_id == stripe_subscription_id + ) + ) + sub = result.scalar_one_or_none() + if sub is None: + return + if tier is not None: + sub.tier = tier + if status is not None: + sub.status = status + if current_period_end is not None: + sub.current_period_end = current_period_end # Module-level singleton shared across the app. diff --git a/app/billing/tier_manager.py b/app/billing/tier_manager.py index fbd6e5d..254dfd7 100644 --- a/app/billing/tier_manager.py +++ b/app/billing/tier_manager.py @@ -1,8 +1,9 @@ """Tier manager: feature matrix and quota enforcement. ``TierManager`` is the single source of truth for what each billing tier -allows. ``get_tier`` reads from the ``StripeService`` in-memory store until -Step 12 replaces it with a live PostgreSQL lookup. +allows. ``get_tier`` queries the ``subscriptions`` table for the live tier. +Quota-enforcement helpers take ``tier`` directly — the caller already has it +from ``current_user.tier`` (provided by ``get_current_user``). """ from __future__ import annotations @@ -10,6 +11,8 @@ from __future__ import annotations from typing import Any from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.schemas import BillingTier @@ -67,55 +70,42 @@ RATE_LIMITS: dict[str, int] = { class TierManager: - """Centralises tier feature-gating, rate-limit lookups, and quota checks. - - ``get_tier`` consults the ``StripeService`` singleton. Step 12 will - replace that with a PostgreSQL query so that the tier is always fresh. - """ + """Centralises tier feature-gating, rate-limit lookups, and quota checks.""" # ── Tier lookup ───────────────────────────────────────────────────── - def get_tier(self, user_id: str) -> BillingTier: - """Return the current billing tier for ``user_id``. + async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: + """Return the current billing tier for ``user_id`` from the DB. - Falls back to ``'free'`` when no subscription record exists. - Step 12 will replace this with a live DB lookup. + Falls back to ``'free'`` when no subscription row exists. """ - # Import here to avoid circular imports at module load time. - from app.billing.stripe_service import stripe_service # noqa: PLC0415 + from app.models import Subscription # noqa: PLC0415 - sub = stripe_service.get_subscription(user_id) - if sub is None: - return "free" - tier = sub.get("tier", "free") - # Validate against known tiers; unknown values fall back to free. - if tier not in FEATURES: + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + tier: str | None = result.scalar_one_or_none() + if tier is None or tier not in FEATURES: return "free" return tier # type: ignore[return-value] # ── Feature access ─────────────────────────────────────────────────── - def check_feature(self, user_id: str, feature: str) -> bool: - """Return ``True`` if ``user_id``'s current tier has ``feature`` enabled. + def check_feature(self, tier: BillingTier, feature: str) -> bool: + """Return ``True`` if ``tier`` has ``feature`` enabled. For numeric features, any value > 0 or -1 (unlimited) counts as enabled. """ - tier = self.get_tier(user_id) - value = FEATURES[tier].get(feature) + value = FEATURES.get(tier, FEATURES["free"]).get(feature) if value is None: return False if isinstance(value, bool): return value - # Numeric: -1 means unlimited (enabled), 0 means disabled. return value != 0 - def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None: - """Raise ``HTTP 403`` if ``user_id`` does not have ``feature``. - - ``tier_name`` is used in the error message to tell users which tier - they need to upgrade to. - """ - if not self.check_feature(user_id, feature): + def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None: + """Raise ``HTTP 403`` if ``tier`` does not have ``feature``.""" + if not self.check_feature(tier, feature): detail = ( f"Feature '{feature}' requires {tier_name} tier or above." if tier_name @@ -131,39 +121,17 @@ class TierManager: # ── Storage quota ──────────────────────────────────────────────────── - def check_quota( - self, - user_id: str, - current_bytes: int = 0, - additional_bytes: int = 0, - ) -> bool: - """Return ``True`` if ``user_id`` can store ``additional_bytes`` more data. - - ``current_bytes`` is the user's current storage usage (from the - caller's record-keeping). Step 12 will remove these parameters and - query the DB directly. - - Returns ``False`` if the tier has no storage allocation at all - (free tier), or if ``current_bytes + additional_bytes`` would exceed - the tier's ``cloud_storage_gb`` limit. - """ - tier = self.get_tier(user_id) - limit_gb: int = FEATURES[tier]["cloud_storage_gb"] - if limit_gb == 0: - return False # tier has no storage - if limit_gb == -1: - return True # unlimited - limit_bytes = limit_gb * 1024 ** 3 - return current_bytes + additional_bytes <= limit_bytes - def enforce_quota( self, - user_id: str, + tier: BillingTier, current_bytes: int = 0, additional_bytes: int = 0, ) -> None: - """Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota.""" - tier = self.get_tier(user_id) + """Raise ``HTTP 402`` if the user would exceed their cloud storage quota. + + ``tier`` is the caller's current tier (from ``current_user.tier``). + ``current_bytes`` is the total bytes already stored (queried by caller). + """ limit_gb: int = FEATURES[tier]["cloud_storage_gb"] if limit_gb == 0: raise HTTPException( @@ -181,12 +149,11 @@ class TierManager: def enforce_backup_quota( self, - user_id: str, + tier: BillingTier, current_bytes: int = 0, additional_bytes: int = 0, ) -> None: - """Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota.""" - tier = self.get_tier(user_id) + """Raise ``HTTP 402`` if the user would exceed their backup quota.""" limit_gb: int = FEATURES[tier]["backup_gb"] if limit_gb == 0: raise HTTPException( @@ -202,6 +169,21 @@ class TierManager: detail=f"Backup quota exceeded for tier '{tier}'", ) + def check_quota( + self, + tier: BillingTier, + current_bytes: int = 0, + additional_bytes: int = 0, + ) -> bool: + """Return ``True`` if the user can store ``additional_bytes`` more data.""" + limit_gb: int = FEATURES[tier]["cloud_storage_gb"] + if limit_gb == 0: + return False + if limit_gb == -1: + return True + limit_bytes = limit_gb * 1024 ** 3 + return current_bytes + additional_bytes <= limit_bytes + # Module-level singleton shared across the app. tier_manager = TierManager() diff --git a/app/db.py b/app/db.py new file mode 100644 index 0000000..38a8d27 --- /dev/null +++ b/app/db.py @@ -0,0 +1,40 @@ +"""Database engine, session factory, and base model. + +All app code uses the async SQLAlchemy API. Alembic migrations use the +synchronous psycopg2 URL for the CLI (see alembic/env.py). + +Usage in routes: + from app.db import get_session + from sqlalchemy.ext.asyncio import AsyncSession + + async def my_route(db: AsyncSession = Depends(get_session)): + result = await db.execute(select(User).where(User.email == email)) + user = result.scalar_one_or_none() +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import DeclarativeBase + +from app.config.settings import settings + +engine = create_async_engine( + settings.DATABASE_URL, + pool_pre_ping=True, + echo=settings.ENV == "dev", +) + +async_session = async_sessionmaker(engine, expire_on_commit=False) + + +class Base(DeclarativeBase): + """Shared declarative base for all ORM models.""" + + +async def get_session() -> AsyncGenerator[AsyncSession, None]: + """FastAPI dependency that yields an async DB session per request.""" + async with async_session() as session: + yield session diff --git a/app/main.py b/app/main.py index 8db1a20..29d7230 100644 --- a/app/main.py +++ b/app/main.py @@ -16,7 +16,9 @@ async def lifespan(app: FastAPI): yield - # Shutdown: nothing to clean up for now + # Shutdown: dispose SQLAlchemy connection pool + from app.db import engine + await engine.dispose() def create_app() -> FastAPI: diff --git a/app/models.py b/app/models.py new file mode 100644 index 0000000..ee5ba03 --- /dev/null +++ b/app/models.py @@ -0,0 +1,269 @@ +"""SQLAlchemy ORM models for all persistent tables. + +Only auth, billing, storage metadata, and marketplace data live here. +User content (notes, tasks, etc.) is NEVER persisted server-side — +it lives in E2E-encrypted blobs in S3, referenced by storage_records. + +Table inventory: + users — account credentials + tier + refresh_tokens — hashed refresh token store + subscriptions — Stripe subscription records + storage_records — S3 blob metadata (no plaintext) + backup_metadata — encrypted backup manifests + plugins — marketplace plugin catalog + plugin_installations — per-user install records + plugin_reviews — admin review decisions + revenue_events — Stripe Connect 70/30 split ledger +""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone + +from sqlalchemy import ( + BigInteger, + Boolean, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + String, + Text, + UniqueConstraint, + func, +) +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from app.db import Base + +# ── Helpers ────────────────────────────────────────────────────────────── + + +def _uuid() -> str: + return str(uuid.uuid4()) + + +def _now() -> datetime: + return datetime.now(timezone.utc) + + +# ── Enum types ──────────────────────────────────────────────────────────── + +TierEnum = Enum("free", "pro", "power", "team", name="billing_tier") +PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status") +ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision") + + +# ── Models ──────────────────────────────────────────────────────────────── + + +class User(Base): + __tablename__ = "users" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) + password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free") + stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + refresh_tokens: Mapped[list[RefreshToken]] = relationship( + back_populates="user", cascade="all, delete-orphan" + ) + subscription: Mapped[Subscription | None] = relationship( + back_populates="user", uselist=False, cascade="all, delete-orphan" + ) + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + user_id: Mapped[str] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + ) + token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) + expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + user: Mapped[User] = relationship(back_populates="refresh_tokens") + + +class Subscription(Base): + __tablename__ = "subscriptions" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + user_id: Mapped[str] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, unique=True, index=True + ) + stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True) + tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free") + status: Mapped[str] = mapped_column(String(50), nullable=False, default="free") + current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + user: Mapped[User] = relationship(back_populates="subscription") + + +class StorageRecord(Base): + __tablename__ = "storage_records" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + user_id: Mapped[str] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + ) + table_name: Mapped[str] = mapped_column(String(100), nullable=False) + s3_key: Mapped[str] = mapped_column(String(500), nullable=False) + checksum: Mapped[str] = mapped_column(String(64), nullable=False) + size_bytes: Mapped[int] = mapped_column(Integer, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() + ) + + +class BackupMetadata(Base): + __tablename__ = "backup_metadata" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + user_id: Mapped[str] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + ) + s3_key: Mapped[str] = mapped_column(String(500), nullable=False) + version: Mapped[int] = mapped_column(Integer, nullable=False) + timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False) + checksum: Mapped[str] = mapped_column(String(64), nullable=False) + size_bytes: Mapped[int] = mapped_column(Integer, nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + +class Plugin(Base): + __tablename__ = "plugins" + + id: Mapped[str] = mapped_column(String(255), primary_key=True) + name: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str] = mapped_column(Text, nullable=False, default="") + version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0") + # nullable until developer account system is built + author_id: Mapped[str | None] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="") + category: Mapped[str] = mapped_column(String(100), nullable=False, default="") + price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list + status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review") + s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True) + install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) + rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True) + submitted_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + installations: Mapped[list[PluginInstallation]] = relationship( + back_populates="plugin", cascade="all, delete-orphan" + ) + reviews: Mapped[list[PluginReview]] = relationship( + back_populates="plugin", cascade="all, delete-orphan" + ) + revenue_events: Mapped[list[RevenueEvent]] = relationship( + back_populates="plugin", cascade="all, delete-orphan" + ) + + +class PluginInstallation(Base): + __tablename__ = "plugin_installations" + __table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),) + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + plugin_id: Mapped[str] = mapped_column( + String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True + ) + user_id: Mapped[str] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + ) + installed_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + plugin: Mapped[Plugin] = relationship(back_populates="installations") + + +class PluginReview(Base): + __tablename__ = "plugin_reviews" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + plugin_id: Mapped[str] = mapped_column( + String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True + ) + reviewer_id: Mapped[str | None] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + ) + decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False) + notes: Mapped[str | None] = mapped_column(Text, nullable=True) + reviewed_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + plugin: Mapped[Plugin] = relationship(back_populates="reviews") + + +class RevenueEvent(Base): + __tablename__ = "revenue_events" + + id: Mapped[str] = mapped_column( + UUID(as_uuid=False), primary_key=True, default=_uuid + ) + plugin_id: Mapped[str] = mapped_column( + String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True + ) + user_id: Mapped[str] = mapped_column( + UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + ) + amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True) + paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")