step 12
This commit is contained in:
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Alembic configuration file.
|
||||||
|
# The async app uses postgresql+asyncpg:// at runtime.
|
||||||
|
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = .
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Alembic migration environment — async-compatible.
|
||||||
|
|
||||||
|
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||||
|
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||||
|
env var by replacing the driver prefix.
|
||||||
|
|
||||||
|
Run migrations with:
|
||||||
|
alembic upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Set up Python logging from alembic.ini.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||||
|
from app.models import Base # noqa: E402
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_url(async_url: str) -> str:
|
||||||
|
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||||
|
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
db_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not db_url:
|
||||||
|
# Fall back to settings if env var not set directly.
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
db_url = settings.DATABASE_URL
|
||||||
|
return _sync_url(db_url)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Emit SQL without a live DB connection."""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online_async() -> None:
|
||||||
|
"""Run migrations against a live DB using the async engine."""
|
||||||
|
async_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not async_url:
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
async_url = settings.DATABASE_URL
|
||||||
|
|
||||||
|
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
asyncio.run(run_migrations_online_async())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
202
alembic/versions/001_initial_schema.py
Normal file
202
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
|
Revision ID: 001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "001"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────
|
||||||
|
billing_tier = postgresql.ENUM(
|
||||||
|
"free", "pro", "power", "team", name="billing_tier", create_type=False
|
||||||
|
)
|
||||||
|
plugin_status = postgresql.ENUM(
|
||||||
|
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
|
||||||
|
)
|
||||||
|
review_decision = postgresql.ENUM(
|
||||||
|
"approved", "rejected", name="review_decision", create_type=False
|
||||||
|
)
|
||||||
|
for enum in (billing_tier, plugin_status, review_decision):
|
||||||
|
enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("email", sa.String(255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
|
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||||
|
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("email"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_users_email", "users", ["email"])
|
||||||
|
|
||||||
|
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"refresh_tokens",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("token_hash"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||||
|
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||||
|
|
||||||
|
# ── subscriptions ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"subscriptions",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||||
|
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||||
|
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("user_id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
|
op.drop_table("subscriptions")
|
||||||
|
op.drop_table("refresh_tokens")
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Auth middleware — JWT validation dependency.
|
"""Auth middleware — JWT validation dependency.
|
||||||
|
|
||||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||||
It decodes the Bearer JWT, validates signature and expiry, and returns a
|
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||||
``UserProfile`` carrying ``id``, ``email``, and ``tier``.
|
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||||
|
without requiring token re-issue.
|
||||||
|
|
||||||
Exempt routes (no JWT required):
|
Exempt routes (no JWT required):
|
||||||
- POST /api/v1/auth/register
|
- POST /api/v1/auth/register
|
||||||
@@ -15,8 +16,11 @@ from __future__ import annotations
|
|||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
from app.schemas import UserProfile
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
@@ -24,12 +28,15 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> UserProfile:
|
) -> UserProfile:
|
||||||
"""Validate a Bearer JWT and return the authenticated user.
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry only. The tier is fetched live
|
||||||
|
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||||
|
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
|
||||||
Raises HTTP 401 on any invalid or expired token.
|
Raises HTTP 401 on any invalid or expired token.
|
||||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
|
||||||
adds a live DB lookup.
|
|
||||||
"""
|
"""
|
||||||
credentials_exc = HTTPException(
|
credentials_exc = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -42,10 +49,17 @@ async def get_current_user(
|
|||||||
)
|
)
|
||||||
user_id: str | None = payload.get("sub")
|
user_id: str | None = payload.get("sub")
|
||||||
email: str | None = payload.get("email")
|
email: str | None = payload.get("email")
|
||||||
tier: str = payload.get("tier", "free")
|
|
||||||
if not user_id or not email:
|
if not user_id or not email:
|
||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
|
|
||||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -1,33 +1,36 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
migrates them to PostgreSQL.
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
|
||||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
|
||||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
def _hash_password(password: str) -> str:
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
|
|||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (signed JWT, expires_at_ms)."""
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
access_payload = {
|
payload = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"email": email,
|
"email": email,
|
||||||
"tier": tier,
|
"tier": tier,
|
||||||
"exp": access_exp,
|
"exp": exp,
|
||||||
"iat": now,
|
"iat": now,
|
||||||
}
|
}
|
||||||
access_token = jwt.encode(
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
return token, exp * 1000 # ms for client
|
||||||
)
|
|
||||||
refresh_token = str(uuid.uuid4())
|
|
||||||
_refresh_tokens[refresh_token] = user_id
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
expires_at=access_exp * 1000, # milliseconds for client
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ────────────────────────────────────────────────────
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class _RegisterRequest(BaseModel):
|
class _RegisterRequest(BaseModel):
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
|
|||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Create a new account and return JWT tokens."""
|
"""Create a new account and return JWT tokens."""
|
||||||
if body.email in _users:
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
user_id = str(uuid.uuid4())
|
|
||||||
_users[body.email] = {
|
user = User(
|
||||||
"id": user_id,
|
id=str(uuid.uuid4()),
|
||||||
"email": body.email,
|
email=body.email,
|
||||||
"password_hash": _hash_password(body.password),
|
password_hash=_hash_password(body.password),
|
||||||
"tier": "free",
|
tier="free",
|
||||||
}
|
)
|
||||||
return _make_tokens(user_id, body.email, "free")
|
db.add(user)
|
||||||
|
await db.flush() # get user.id without committing
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=AuthTokens)
|
@router.post("/login", response_model=AuthTokens)
|
||||||
async def login(body: _LoginRequest) -> AuthTokens:
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Validate credentials and return JWT tokens."""
|
"""Validate credentials and return JWT tokens."""
|
||||||
user = _users.get(body.email)
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=AuthTokens)
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Rotate a refresh token and return a new token pair."""
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
token_hash = _hash_token(body.refresh_token)
|
||||||
if user_id is None:
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
|
||||||
|
# Rotate: delete old token, issue new one.
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
@router.get("/me", response_model=UserProfile)
|
||||||
|
|||||||
@@ -11,9 +11,11 @@ from typing import Any
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, Request, status
|
from fastapi import APIRouter, Depends, Header, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.stripe_service import stripe_service
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.db import get_session
|
||||||
from app.schemas import BillingTier, UserProfile
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
@@ -44,6 +46,7 @@ async def create_checkout(
|
|||||||
async def stripe_webhook(
|
async def stripe_webhook(
|
||||||
request: Request,
|
request: Request,
|
||||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Handle Stripe webhook events.
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
@@ -51,16 +54,17 @@ async def stripe_webhook(
|
|||||||
Returns 200 immediately when Stripe is not configured (local dev).
|
Returns 200 immediately when Stripe is not configured (local dev).
|
||||||
"""
|
"""
|
||||||
payload = await request.body()
|
payload = await request.body()
|
||||||
stripe_service.handle_webhook(payload, stripe_signature)
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/subscription", response_model=dict)
|
@router.get("/subscription", response_model=dict)
|
||||||
async def get_subscription(
|
async def get_subscription(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return the current subscription info for the authenticated user."""
|
"""Return the current subscription info for the authenticated user."""
|
||||||
sub = stripe_service.get_subscription(current_user.id)
|
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||||
if sub is None:
|
if sub is None:
|
||||||
return {
|
return {
|
||||||
"tier": current_user.tier,
|
"tier": current_user.tier,
|
||||||
@@ -74,7 +78,8 @@ async def get_subscription(
|
|||||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||||
async def cancel_subscription(
|
async def cancel_subscription(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
stripe_service.cancel_subscription(current_user.id)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,17 +1,19 @@
|
|||||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
Subscriptions are stored in-memory until Step 12 migrates them to the
|
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||||
PostgreSQL ``subscriptions`` table. All Stripe calls are gracefully stubbed
|
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||||
when ``STRIPE_SECRET_KEY`` is not configured, enabling local development
|
configured, enabling local development without live credentials.
|
||||||
without live credentials.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import stripe as stripe_lib
|
import stripe as stripe_lib
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
@@ -24,15 +26,7 @@ TIER_PRICE_IDS: dict[str, str] = {
|
|||||||
|
|
||||||
|
|
||||||
class StripeService:
|
class StripeService:
|
||||||
"""Wraps all Stripe interactions and owns the in-memory subscription store.
|
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||||
|
|
||||||
Step 12 will replace ``_subscriptions`` with real PostgreSQL queries.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# user_id → subscription record dict
|
|
||||||
# Replaced by the ``subscriptions`` table in Step 12.
|
|
||||||
self._subscriptions: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
# ── Internal helpers ────────────────────────────────────────────────
|
# ── Internal helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -84,7 +78,12 @@ class StripeService:
|
|||||||
)
|
)
|
||||||
return session.url
|
return session.url
|
||||||
|
|
||||||
def handle_webhook(self, payload: bytes, sig_header: str) -> None:
|
async def handle_webhook(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
"""Process a Stripe webhook event.
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
Verifies the signature, then dispatches on event type.
|
Verifies the signature, then dispatches on event type.
|
||||||
@@ -112,57 +111,82 @@ class StripeService:
|
|||||||
user_id = data.get("metadata", {}).get("user_id")
|
user_id = data.get("metadata", {}).get("user_id")
|
||||||
tier = data.get("metadata", {}).get("tier", "free")
|
tier = data.get("metadata", {}).get("tier", "free")
|
||||||
sub_id = data.get("subscription")
|
sub_id = data.get("subscription")
|
||||||
period_end = data.get("current_period_end")
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
if user_id:
|
if user_id:
|
||||||
self._subscriptions[user_id] = {
|
await self._upsert_subscription(
|
||||||
"tier": tier,
|
db, user_id, sub_id, tier, "active", period_end
|
||||||
"stripe_subscription_id": sub_id,
|
)
|
||||||
"status": "active",
|
|
||||||
"current_period_end": period_end,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.updated":
|
elif event_type == "customer.subscription.updated":
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, update tier
|
|
||||||
sub_id = data.get("id")
|
sub_id = data.get("id")
|
||||||
new_status = data.get("status")
|
new_status = data.get("status", "active")
|
||||||
period_end = data.get("current_period_end")
|
period_end_ts = data.get("current_period_end")
|
||||||
for record in self._subscriptions.values():
|
period_end = (
|
||||||
if record.get("stripe_subscription_id") == sub_id:
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
record["status"] = new_status
|
if period_end_ts
|
||||||
record["current_period_end"] = period_end
|
else None
|
||||||
break
|
)
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status=new_status, current_period_end=period_end
|
||||||
|
)
|
||||||
|
|
||||||
elif event_type == "customer.subscription.deleted":
|
elif event_type == "customer.subscription.deleted":
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
|
||||||
sub_id = data.get("id")
|
sub_id = data.get("id")
|
||||||
for user_id, record in self._subscriptions.items():
|
if sub_id:
|
||||||
if record.get("stripe_subscription_id") == sub_id:
|
await self._update_subscription_by_stripe_id(
|
||||||
self._subscriptions[user_id] = {
|
db, sub_id, tier="free", status="canceled"
|
||||||
**record,
|
)
|
||||||
"tier": "free",
|
|
||||||
"status": "canceled",
|
|
||||||
}
|
|
||||||
break
|
|
||||||
|
|
||||||
elif event_type == "invoice.payment_failed":
|
elif event_type == "invoice.payment_failed":
|
||||||
# TODO(Step12): flag subscription as past_due, notify user
|
|
||||||
sub_id = data.get("subscription")
|
sub_id = data.get("subscription")
|
||||||
for record in self._subscriptions.values():
|
if sub_id:
|
||||||
if record.get("stripe_subscription_id") == sub_id:
|
await self._update_subscription_by_stripe_id(
|
||||||
record["status"] = "past_due"
|
db, sub_id, status="past_due"
|
||||||
break
|
)
|
||||||
|
|
||||||
def get_subscription(self, user_id: str) -> dict[str, Any] | None:
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_subscription(
|
||||||
|
self, user_id: str, db: AsyncSession
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||||
return self._subscriptions.get(user_id)
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
def cancel_subscription(self, user_id: str) -> None:
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"tier": sub.tier,
|
||||||
|
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||||
|
"status": sub.status,
|
||||||
|
"current_period_end": (
|
||||||
|
int(sub.current_period_end.timestamp() * 1000)
|
||||||
|
if sub.current_period_end
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||||
|
|
||||||
Raises ``HTTP 404`` when no active subscription exists.
|
Raises ``HTTP 404`` when no active subscription exists.
|
||||||
"""
|
"""
|
||||||
sub = self._subscriptions.get(user_id)
|
from app.models import Subscription # noqa: PLC0415
|
||||||
if sub is None or not sub.get("stripe_subscription_id"):
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None or not sub.stripe_subscription_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="No active subscription found",
|
detail="No active subscription found",
|
||||||
@@ -170,13 +194,62 @@ class StripeService:
|
|||||||
|
|
||||||
if self._configured():
|
if self._configured():
|
||||||
s = self._client()
|
s = self._client()
|
||||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||||
|
|
||||||
self._subscriptions[user_id] = {
|
sub.tier = "free"
|
||||||
**sub,
|
sub.status = "canceled"
|
||||||
"tier": "free",
|
await db.commit()
|
||||||
"status": "canceled",
|
|
||||||
}
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _upsert_subscription(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
stripe_subscription_id: str | None,
|
||||||
|
tier: str,
|
||||||
|
sub_status: str,
|
||||||
|
current_period_end: datetime | None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
sub = Subscription(user_id=user_id)
|
||||||
|
db.add(sub)
|
||||||
|
sub.stripe_subscription_id = stripe_subscription_id
|
||||||
|
sub.tier = tier
|
||||||
|
sub.status = sub_status
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
async def _update_subscription_by_stripe_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
stripe_subscription_id: str,
|
||||||
|
*,
|
||||||
|
tier: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
current_period_end: datetime | None = None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if tier is not None:
|
||||||
|
sub.tier = tier
|
||||||
|
if status is not None:
|
||||||
|
sub.status = status
|
||||||
|
if current_period_end is not None:
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton shared across the app.
|
# Module-level singleton shared across the app.
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Tier manager: feature matrix and quota enforcement.
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
``TierManager`` is the single source of truth for what each billing tier
|
``TierManager`` is the single source of truth for what each billing tier
|
||||||
allows. ``get_tier`` reads from the ``StripeService`` in-memory store until
|
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||||
Step 12 replaces it with a live PostgreSQL lookup.
|
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||||
|
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -10,6 +11,8 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, status
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.schemas import BillingTier
|
from app.schemas import BillingTier
|
||||||
|
|
||||||
@@ -67,55 +70,42 @@ RATE_LIMITS: dict[str, int] = {
|
|||||||
|
|
||||||
|
|
||||||
class TierManager:
|
class TierManager:
|
||||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks.
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
``get_tier`` consults the ``StripeService`` singleton. Step 12 will
|
|
||||||
replace that with a PostgreSQL query so that the tier is always fresh.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_tier(self, user_id: str) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id``.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'free'`` when no subscription record exists.
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
Step 12 will replace this with a live DB lookup.
|
|
||||||
"""
|
"""
|
||||||
# Import here to avoid circular imports at module load time.
|
from app.models import Subscription # noqa: PLC0415
|
||||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
|
||||||
|
|
||||||
sub = stripe_service.get_subscription(user_id)
|
result = await db.execute(
|
||||||
if sub is None:
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
return "free"
|
)
|
||||||
tier = sub.get("tier", "free")
|
tier: str | None = result.scalar_one_or_none()
|
||||||
# Validate against known tiers; unknown values fall back to free.
|
if tier is None or tier not in FEATURES:
|
||||||
if tier not in FEATURES:
|
|
||||||
return "free"
|
return "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|
||||||
def check_feature(self, user_id: str, feature: str) -> bool:
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
"""Return ``True`` if ``user_id``'s current tier has ``feature`` enabled.
|
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||||
|
|
||||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||||
"""
|
"""
|
||||||
tier = self.get_tier(user_id)
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
value = FEATURES[tier].get(feature)
|
|
||||||
if value is None:
|
if value is None:
|
||||||
return False
|
return False
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return value
|
return value
|
||||||
# Numeric: -1 means unlimited (enabled), 0 means disabled.
|
|
||||||
return value != 0
|
return value != 0
|
||||||
|
|
||||||
def require_feature(self, user_id: str, feature: str, tier_name: str = "") -> None:
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
"""Raise ``HTTP 403`` if ``user_id`` does not have ``feature``.
|
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||||
|
if not self.check_feature(tier, feature):
|
||||||
``tier_name`` is used in the error message to tell users which tier
|
|
||||||
they need to upgrade to.
|
|
||||||
"""
|
|
||||||
if not self.check_feature(user_id, feature):
|
|
||||||
detail = (
|
detail = (
|
||||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
if tier_name
|
if tier_name
|
||||||
@@ -131,39 +121,17 @@ class TierManager:
|
|||||||
|
|
||||||
# ── Storage quota ────────────────────────────────────────────────────
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def check_quota(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> bool:
|
|
||||||
"""Return ``True`` if ``user_id`` can store ``additional_bytes`` more data.
|
|
||||||
|
|
||||||
``current_bytes`` is the user's current storage usage (from the
|
|
||||||
caller's record-keeping). Step 12 will remove these parameters and
|
|
||||||
query the DB directly.
|
|
||||||
|
|
||||||
Returns ``False`` if the tier has no storage allocation at all
|
|
||||||
(free tier), or if ``current_bytes + additional_bytes`` would exceed
|
|
||||||
the tier's ``cloud_storage_gb`` limit.
|
|
||||||
"""
|
|
||||||
tier = self.get_tier(user_id)
|
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
return False # tier has no storage
|
|
||||||
if limit_gb == -1:
|
|
||||||
return True # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
return current_bytes + additional_bytes <= limit_bytes
|
|
||||||
|
|
||||||
def enforce_quota(
|
def enforce_quota(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
tier: BillingTier,
|
||||||
current_bytes: int = 0,
|
current_bytes: int = 0,
|
||||||
additional_bytes: int = 0,
|
additional_bytes: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their storage quota."""
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
tier = self.get_tier(user_id)
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
if limit_gb == 0:
|
if limit_gb == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -181,12 +149,11 @@ class TierManager:
|
|||||||
|
|
||||||
def enforce_backup_quota(
|
def enforce_backup_quota(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
tier: BillingTier,
|
||||||
current_bytes: int = 0,
|
current_bytes: int = 0,
|
||||||
additional_bytes: int = 0,
|
additional_bytes: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise ``HTTP 402`` if ``user_id`` would exceed their backup quota."""
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
tier = self.get_tier(user_id)
|
|
||||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
if limit_gb == 0:
|
if limit_gb == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -202,6 +169,21 @@ class TierManager:
|
|||||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton shared across the app.
|
# Module-level singleton shared across the app.
|
||||||
tier_manager = TierManager()
|
tier_manager = TierManager()
|
||||||
|
|||||||
40
app/db.py
Normal file
40
app/db.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Database engine, session factory, and base model.
|
||||||
|
|
||||||
|
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||||
|
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||||
|
|
||||||
|
Usage in routes:
|
||||||
|
from app.db import get_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
echo=settings.ENV == "dev",
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Shared declarative base for all ORM models."""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""FastAPI dependency that yields an async DB session per request."""
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
@@ -16,7 +16,9 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: nothing to clean up for now
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
|
from app.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
|
|||||||
269
app/models.py
Normal file
269
app/models.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
|
Table inventory:
|
||||||
|
users — account credentials + tier
|
||||||
|
refresh_tokens — hashed refresh token store
|
||||||
|
subscriptions — Stripe subscription records
|
||||||
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
|
backup_metadata — encrypted backup manifests
|
||||||
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(Base):
|
||||||
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, unique=True, index=True
|
||||||
|
)
|
||||||
|
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||||
|
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
Reference in New Issue
Block a user