step 12
This commit is contained in:
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
||||
# Alembic configuration file.
|
||||
# The async app uses postgresql+asyncpg:// at runtime.
|
||||
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
|
||||
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Alembic migration environment — async-compatible.
|
||||
|
||||
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||
env var by replacing the driver prefix.
|
||||
|
||||
Run migrations with:
|
||||
alembic upgrade head
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
# Alembic Config object (gives access to alembic.ini values).
|
||||
config = context.config
|
||||
|
||||
# Set up Python logging from alembic.ini.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||
from app.models import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def _sync_url(async_url: str) -> str:
|
||||
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||
|
||||
|
||||
def _get_url() -> str:
|
||||
db_url = os.environ.get("DATABASE_URL", "")
|
||||
if not db_url:
|
||||
# Fall back to settings if env var not set directly.
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
db_url = settings.DATABASE_URL
|
||||
return _sync_url(db_url)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Emit SQL without a live DB connection."""
|
||||
url = _get_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online_async() -> None:
|
||||
"""Run migrations against a live DB using the async engine."""
|
||||
async_url = os.environ.get("DATABASE_URL", "")
|
||||
if not async_url:
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
async_url = settings.DATABASE_URL
|
||||
|
||||
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
asyncio.run(run_migrations_online_async())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
202
alembic/versions/001_initial_schema.py
Normal file
202
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2026-03-02
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Enum types ────────────────────────────────────────────────────────
|
||||
billing_tier = postgresql.ENUM(
|
||||
"free", "pro", "power", "team", name="billing_tier", create_type=False
|
||||
)
|
||||
plugin_status = postgresql.ENUM(
|
||||
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
|
||||
)
|
||||
review_decision = postgresql.ENUM(
|
||||
"approved", "rejected", name="review_decision", create_type=False
|
||||
)
|
||||
for enum in (billing_tier, plugin_status, review_decision):
|
||||
enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# ── users ─────────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("email", sa.String(255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email"),
|
||||
)
|
||||
op.create_index("ix_users_email", "users", ["email"])
|
||||
|
||||
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"refresh_tokens",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("token_hash"),
|
||||
)
|
||||
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||
|
||||
# ── subscriptions ─────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"subscriptions",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||
|
||||
# ── storage_records ───────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"storage_records",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("table_name", sa.String(100), nullable=False),
|
||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||
sa.Column("checksum", sa.String(64), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||
|
||||
# ── backup_metadata ───────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"backup_metadata",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||
sa.Column("version", sa.Integer, nullable=False),
|
||||
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||
sa.Column("checksum", sa.String(64), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||
|
||||
# ── plugins ───────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugins",
|
||||
sa.Column("id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"),
|
||||
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||
)
|
||||
|
||||
# ── plugin_installations ──────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugin_installations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||
)
|
||||
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||
|
||||
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugin_reviews",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False),
|
||||
sa.Column("notes", sa.Text, nullable=True),
|
||||
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||
)
|
||||
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||
|
||||
# ── revenue_events ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"revenue_events",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("revenue_events")
|
||||
op.drop_table("plugin_reviews")
|
||||
op.drop_table("plugin_installations")
|
||||
op.drop_table("plugins")
|
||||
op.drop_table("backup_metadata")
|
||||
op.drop_table("storage_records")
|
||||
op.drop_table("subscriptions")
|
||||
op.drop_table("refresh_tokens")
|
||||
op.drop_table("users")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Auth middleware — JWT validation dependency.
|
||||
|
||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||
It decodes the Bearer JWT, validates signature and expiry, and returns a
|
||||
``UserProfile`` carrying ``id``, ``email``, and ``tier``.
|
||||
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||
without requiring token re-issue.
|
||||
|
||||
Exempt routes (no JWT required):
|
||||
- POST /api/v1/auth/register
|
||||
@@ -15,8 +16,11 @@ from __future__ import annotations
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.schemas import UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
@@ -24,12 +28,15 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
The JWT is used for identity and expiry only. The tier is fetched live
|
||||
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||
|
||||
Raises HTTP 401 on any invalid or expired token.
|
||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
||||
adds a live DB lookup.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -42,10 +49,17 @@ async def get_current_user(
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
tier: str = payload.get("tier", "free")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
# Live tier lookup — subscription row is the authoritative source.
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str = result.scalar_one_or_none() or "free"
|
||||
|
||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||
|
||||
@@ -1,33 +1,36 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
|
||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
||||
migrates them to PostgreSQL.
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.models import RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
access_payload = {
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": access_exp,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
access_token = jwt.encode(
|
||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
refresh_token = str(uuid.uuid4())
|
||||
_refresh_tokens[refresh_token] = user_id
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=access_exp * 1000, # milliseconds for client
|
||||
)
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
if body.email in _users:
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
user_id = str(uuid.uuid4())
|
||||
_users[body.email] = {
|
||||
"id": user_id,
|
||||
"email": body.email,
|
||||
"password_hash": _hash_password(body.password),
|
||||
"tier": "free",
|
||||
}
|
||||
return _make_tokens(user_id, body.email, "free")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(body: _LoginRequest) -> AuthTokens:
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
user = _users.get(body.email)
|
||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
||||
if user_id is None:
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
||||
|
||||
# Rotate: delete old token, issue new one.
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
|
||||
@@ -11,9 +11,11 @@ from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.db import get_session
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
@@ -44,6 +46,7 @@ async def create_checkout(
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
@@ -51,16 +54,17 @@ async def stripe_webhook(
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
stripe_service.handle_webhook(payload, stripe_signature)
|
||||
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=dict)
|
||||
async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = stripe_service.get_subscription(current_user.id)
|
||||
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
@@ -74,7 +78,8 @@ async def get_subscription(
|
||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
stripe_service.cancel_subscription(current_user.id)
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||
|
||||
Subscriptions are stored in-memory until Step 12 migrates them to the
|
||||
PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed
|
||||
when ``STRIPE_SECRET_KEY`` is not configured, enabling local development
|
||||
without live credentials.
|
||||
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||
configured, enabling local development without live credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
@@ -24,15 +26,7 @@ TIER_PRICE_IDS: dict[str, str] = {
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Wraps all Stripe interactions and owns the in-memory subscription store.
|
||||
|
||||
Step 12 will replace ``_subscriptions`` with real PostgreSQL queries.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# user_id → subscription record dict
|
||||
# Replaced by the ``subscriptions`` table in Step 12.
|
||||
self._subscriptions: dict[str, dict[str, Any]] = {}
|
||||
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||
|
||||
# ── Internal helpers ────────────────────────────────────────────────
|
||||
|
||||
@@ -84,7 +78,12 @@ class StripeService:
|
||||
)
|
||||
return session.url
|
||||
|
||||
def handle_webhook(self, payload: bytes, sig_header: str) -> None:
|
||||
async def handle_webhook(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Process a Stripe webhook event.
|
||||
|
||||
Verifies the signature, then dispatches on event type.
|
||||
@@ -112,57 +111,82 @@ class StripeService:
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
period_end = data.get("current_period_end")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if user_id:
|
||||
self._subscriptions[user_id] = {
|
||||
"tier": tier,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"status": "active",
|
||||
"current_period_end": period_end,
|
||||
}
|
||||
await self._upsert_subscription(
|
||||
db, user_id, sub_id, tier, "active", period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, update tier
|
||||
sub_id = data.get("id")
|
||||
new_status = data.get("status")
|
||||
period_end = data.get("current_period_end")
|
||||
for record in self._subscriptions.values():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
record["status"] = new_status
|
||||
record["current_period_end"] = period_end
|
||||
break
|
||||
new_status = data.get("status", "active")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status=new_status, current_period_end=period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
||||
sub_id = data.get("id")
|
||||
for user_id, record in self._subscriptions.items():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
self._subscriptions[user_id] = {
|
||||
**record,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
break
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, tier="free", status="canceled"
|
||||
)
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
# TODO(Step12): flag subscription as past_due, notify user
|
||||
sub_id = data.get("subscription")
|
||||
for record in self._subscriptions.values():
|
||||
if record.get("stripe_subscription_id") == sub_id:
|
||||
record["status"] = "past_due"
|
||||
break
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status="past_due"
|
||||
)
|
||||
|
||||
def get_subscription(self, user_id: str) -> dict[str, Any] | None:
|
||||
await db.commit()
|
||||
|
||||
async def get_subscription(
|
||||
self, user_id: str, db: AsyncSession
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||
return self._subscriptions.get(user_id)
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
def cancel_subscription(self, user_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return None
|
||||
return {
|
||||
"tier": sub.tier,
|
||||
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||
"status": sub.status,
|
||||
"current_period_end": (
|
||||
int(sub.current_period_end.timestamp() * 1000)
|
||||
if sub.current_period_end
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||
|
||||
Raises ``HTTP 404`` when no active subscription exists.
|
||||
"""
|
||||
sub = self._subscriptions.get(user_id)
|
||||
if sub is None or not sub.get("stripe_subscription_id"):
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None or not sub.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
@@ -170,13 +194,62 @@ class StripeService:
|
||||
|
||||
if self._configured():
|
||||
s = self._client()
|
||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
||||
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||
|
||||
self._subscriptions[user_id] = {
|
||||
**sub,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
sub.tier = "free"
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
stripe_subscription_id: str | None,
|
||||
tier: str,
|
||||
sub_status: str,
|
||||
current_period_end: datetime | None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
sub = Subscription(user_id=user_id)
|
||||
db.add(sub)
|
||||
sub.stripe_subscription_id = stripe_subscription_id
|
||||
sub.tier = tier
|
||||
sub.status = sub_status
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
async def _update_subscription_by_stripe_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
stripe_subscription_id: str,
|
||||
*,
|
||||
tier: str | None = None,
|
||||
status: str | None = None,
|
||||
current_period_end: datetime | None = None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||
)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
if tier is not None:
|
||||
sub.tier = tier
|
||||
if status is not None:
|
||||
sub.status = status
|
||||
if current_period_end is not None:
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` reads from the ``StripeService`` in-memory store until
|
||||
Step 12 replaces it with a live PostgreSQL lookup.
|
||||
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -10,6 +11,8 @@ from __future__ import annotations
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
@@ -67,55 +70,42 @@ RATE_LIMITS: dict[str, int] = {
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks.
|
||||
|
||||
``get_tier`` consults the ``StripeService`` singleton. Step 12 will
|
||||
replace that with a PostgreSQL query so that the tier is always fresh.
|
||||
"""
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
def get_tier(self, user_id: str) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id``.
|
||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id`` from the DB.
|
||||
|
||||
Falls back to ``'free'`` when no subscription record exists.
|
||||
Step 12 will replace this with a live DB lookup.
|
||||
Falls back to ``'free'`` when no subscription row exists.
|
||||
"""
|
||||
# Import here to avoid circular imports at module load time.
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
sub = stripe_service.get_subscription(user_id)
|
||||
if sub is None:
|
||||
return "free"
|
||||
tier = sub.get("tier", "free")
|
||||
# Validate against known tiers; unknown values fall back to free.
|
||||
if tier not in FEATURES:
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str | None = result.scalar_one_or_none()
|
||||
if tier is None or tier not in FEATURES:
|
||||
return "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, user_id: str, feature: str) -> bool:
|
||||
"""Return ``True`` if ``user_id``'s current tier has ``feature`` enabled.
|
||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
tier = self.get_tier(user_id)
|
||||
value = FEATURES[tier].get(feature)
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
# Numeric: -1 means unlimited (enabled), 0 means disabled.
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``user_id`` does not have ``feature``.
|
||||
|
||||
``tier_name`` is used in the error message to tell users which tier
|
||||
they need to upgrade to.
|
||||
"""
|
||||
if not self.check_feature(user_id, feature):
|
||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||
if not self.check_feature(tier, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
@@ -131,39 +121,17 @@ class TierManager:
|
||||
|
||||
# ── Storage quota ────────────────────────────────────────────────────
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if ``user_id`` can store ``additional_bytes`` more data.
|
||||
|
||||
``current_bytes`` is the user's current storage usage (from the
|
||||
caller's record-keeping). Step 12 will remove these parameters and
|
||||
query the DB directly.
|
||||
|
||||
Returns ``False`` if the tier has no storage allocation at all
|
||||
(free tier), or if ``current_bytes + additional_bytes`` would exceed
|
||||
the tier's ``cloud_storage_gb`` limit.
|
||||
"""
|
||||
tier = self.get_tier(user_id)
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False # tier has no storage
|
||||
if limit_gb == -1:
|
||||
return True # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
def enforce_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota."""
|
||||
tier = self.get_tier(user_id)
|
||||
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||
|
||||
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||
"""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
@@ -181,12 +149,11 @@ class TierManager:
|
||||
|
||||
def enforce_backup_quota(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota."""
|
||||
tier = self.get_tier(user_id)
|
||||
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
@@ -202,6 +169,21 @@ class TierManager:
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False
|
||||
if limit_gb == -1:
|
||||
return True
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
|
||||
40
app/db.py
Normal file
40
app/db.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Database engine, session factory, and base model.
|
||||
|
||||
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||
|
||||
Usage in routes:
|
||||
from app.db import get_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.ENV == "dev",
|
||||
)
|
||||
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Shared declarative base for all ORM models."""
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session per request."""
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
@@ -16,7 +16,9 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: nothing to clean up for now
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
from app.db import engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
|
||||
269
app/models.py
Normal file
269
app/models.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""SQLAlchemy ORM models for all persistent tables.
|
||||
|
||||
Only auth, billing, storage metadata, and marketplace data live here.
|
||||
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||
|
||||
Table inventory:
|
||||
users — account credentials + tier
|
||||
refresh_tokens — hashed refresh token store
|
||||
subscriptions — Stripe subscription records
|
||||
storage_records — S3 blob metadata (no plaintext)
|
||||
backup_metadata — encrypted backup manifests
|
||||
plugins — marketplace plugin catalog
|
||||
plugin_installations — per-user install records
|
||||
plugin_reviews — admin review decisions
|
||||
revenue_events — Stripe Connect 70/30 split ledger
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# ── Enum types ────────────────────────────────────────────────────────────
|
||||
|
||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
subscription: Mapped[Subscription | None] = relationship(
|
||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="subscription")
|
||||
|
||||
|
||||
class StorageRecord(Base):
|
||||
__tablename__ = "storage_records"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class BackupMetadata(Base):
|
||||
__tablename__ = "backup_metadata"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
__tablename__ = "plugins"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||
# nullable until developer account system is built
|
||||
author_id: Mapped[str | None] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
submitted_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
reviews: Mapped[list[PluginReview]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class PluginInstallation(Base):
|
||||
__tablename__ = "plugin_installations"
|
||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
installed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||
|
||||
|
||||
class PluginReview(Base):
|
||||
__tablename__ = "plugin_reviews"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
reviewer_id: Mapped[str | None] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
reviewed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||
|
||||
|
||||
class RevenueEvent(Base):
|
||||
__tablename__ = "revenue_events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||
Reference in New Issue
Block a user