Compare commits
4 Commits
4c4df7335a
...
5d485b3665
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d485b3665 | |||
| 9787befd4a | |||
| 8f7bc25611 | |||
| 3e07fff958 |
@@ -331,14 +331,14 @@ adiuva-api/
|
||||
### Step 9 — Middleware
|
||||
|
||||
#### 9a — Auth middleware
|
||||
- [ ] `app/api/middleware/auth.py`:
|
||||
- [x] `app/api/middleware/auth.py`:
|
||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
||||
- Raises `401` on invalid/expired token
|
||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||
|
||||
#### 9b — Rate limiter
|
||||
- [ ] `app/api/middleware/rate_limit.py`:
|
||||
- [x] `app/api/middleware/rate_limit.py`:
|
||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
||||
- Tier-based limits:
|
||||
- Free: 20 req/min
|
||||
@@ -348,7 +348,7 @@ adiuva-api/
|
||||
- Custom 429 response with `Retry-After` header
|
||||
|
||||
#### 9c — Sanitizer
|
||||
- [ ] `app/api/middleware/sanitizer.py`:
|
||||
- [x] `app/api/middleware/sanitizer.py`:
|
||||
- Response middleware that scans response bodies
|
||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
||||
- Pattern-based detection + exact match against known prompt fingerprints
|
||||
@@ -356,33 +356,33 @@ adiuva-api/
|
||||
|
||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
||||
|
||||
### Step 10 — Plugin Marketplace
|
||||
- [ ] `app/marketplace/plugin_registry.py`:
|
||||
### Step 10 — Plugin Marketplace ✅
|
||||
- [x] `app/marketplace/plugin_registry.py`:
|
||||
- `PluginRegistry`:
|
||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
||||
- [ ] `app/marketplace/plugin_review.py`:
|
||||
- [x] `app/marketplace/plugin_review.py`:
|
||||
- `ReviewQueue`:
|
||||
- `async get_pending() -> list[dict]`
|
||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
||||
- [ ] `app/marketplace/revenue_share.py`:
|
||||
- [x] `app/marketplace/revenue_share.py`:
|
||||
- `RevenueShare`:
|
||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
||||
- `async get_earnings(developer_id, period) -> dict`
|
||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
||||
|
||||
### Step 11 — Billing & Tier management
|
||||
- [ ] `app/billing/stripe_service.py`:
|
||||
### Step 11 — Billing & Tier management ✅
|
||||
- [x] `app/billing/stripe_service.py`:
|
||||
- `create_checkout_session(user_id, tier) -> str`
|
||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
||||
- `get_subscription(user_id) -> dict | None`
|
||||
- `cancel_subscription(user_id) -> None`
|
||||
- [ ] `app/billing/tier_manager.py`:
|
||||
- [x] `app/billing/tier_manager.py`:
|
||||
- `TierManager`:
|
||||
- Feature matrix:
|
||||
```python
|
||||
@@ -433,6 +433,9 @@ adiuva-api/
|
||||
- `check_feature(user_id, feature) -> bool`
|
||||
- `get_rate_limit(tier) -> int`
|
||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
||||
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
||||
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
||||
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||
|
||||
### Step 12 — Database (auth/billing/marketplace only)
|
||||
|
||||
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
||||
# Alembic configuration file.
|
||||
# The async app uses postgresql+asyncpg:// at runtime.
|
||||
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
prepend_sys_path = .
|
||||
version_path_separator = os
|
||||
|
||||
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Alembic migration environment — async-compatible.
|
||||
|
||||
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||
env var by replacing the driver prefix.
|
||||
|
||||
Run migrations with:
|
||||
alembic upgrade head
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
# Alembic Config object (gives access to alembic.ini values).
|
||||
config = context.config
|
||||
|
||||
# Set up Python logging from alembic.ini.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||
from app.models import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def _sync_url(async_url: str) -> str:
|
||||
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||
|
||||
|
||||
def _get_url() -> str:
|
||||
db_url = os.environ.get("DATABASE_URL", "")
|
||||
if not db_url:
|
||||
# Fall back to settings if env var not set directly.
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
db_url = settings.DATABASE_URL
|
||||
return _sync_url(db_url)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Emit SQL without a live DB connection."""
|
||||
url = _get_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
compare_type=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online_async() -> None:
|
||||
"""Run migrations against a live DB using the async engine."""
|
||||
async_url = os.environ.get("DATABASE_URL", "")
|
||||
if not async_url:
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
async_url = settings.DATABASE_URL
|
||||
|
||||
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
asyncio.run(run_migrations_online_async())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
202
alembic/versions/001_initial_schema.py
Normal file
202
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2026-03-02
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── Enum types ────────────────────────────────────────────────────────
|
||||
billing_tier = postgresql.ENUM(
|
||||
"free", "pro", "power", "team", name="billing_tier", create_type=False
|
||||
)
|
||||
plugin_status = postgresql.ENUM(
|
||||
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
|
||||
)
|
||||
review_decision = postgresql.ENUM(
|
||||
"approved", "rejected", name="review_decision", create_type=False
|
||||
)
|
||||
for enum in (billing_tier, plugin_status, review_decision):
|
||||
enum.create(op.get_bind(), checkfirst=True)
|
||||
|
||||
# ── users ─────────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"users",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("email", sa.String(255), nullable=False),
|
||||
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("email"),
|
||||
)
|
||||
op.create_index("ix_users_email", "users", ["email"])
|
||||
|
||||
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"refresh_tokens",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("token_hash"),
|
||||
)
|
||||
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||
|
||||
# ── subscriptions ─────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"subscriptions",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("user_id"),
|
||||
)
|
||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||
|
||||
# ── storage_records ───────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"storage_records",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("table_name", sa.String(100), nullable=False),
|
||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||
sa.Column("checksum", sa.String(64), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||
|
||||
# ── backup_metadata ───────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"backup_metadata",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||
sa.Column("version", sa.Integer, nullable=False),
|
||||
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||
sa.Column("checksum", sa.String(64), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||
|
||||
# ── plugins ───────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugins",
|
||||
sa.Column("id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"),
|
||||
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||
)
|
||||
|
||||
# ── plugin_installations ──────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugin_installations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||
)
|
||||
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||
|
||||
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugin_reviews",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False),
|
||||
sa.Column("notes", sa.Text, nullable=True),
|
||||
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||
)
|
||||
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||
|
||||
# ── revenue_events ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"revenue_events",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("revenue_events")
|
||||
op.drop_table("plugin_reviews")
|
||||
op.drop_table("plugin_installations")
|
||||
op.drop_table("plugins")
|
||||
op.drop_table("backup_metadata")
|
||||
op.drop_table("storage_records")
|
||||
op.drop_table("subscriptions")
|
||||
op.drop_table("refresh_tokens")
|
||||
op.drop_table("users")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||
@@ -1,46 +1,14 @@
|
||||
"""Shared FastAPI dependencies.
|
||||
|
||||
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
|
||||
Step 9 will layer rate-limiting and sanitization middleware on top of this.
|
||||
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
|
||||
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||
(the canonical location per Step 9). This module re-exports them so that all
|
||||
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||
to work without modification.
|
||||
|
||||
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||
instead of reading it from the JWT payload.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
Raises ``HTTP 401`` on any invalid or expired token.
|
||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
||||
adds a live DB lookup.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
tier: str = payload.get("tier", "free")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""API middleware package.
|
||||
|
||||
Exports the three middleware components introduced in Step 9:
|
||||
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||
- Sanitizer: ``SanitizerMiddleware``
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
|
||||
__all__ = [
|
||||
"get_current_user",
|
||||
"oauth2_scheme",
|
||||
"TierRateLimitMiddleware",
|
||||
"limiter",
|
||||
"SanitizerMiddleware",
|
||||
]
|
||||
|
||||
65
app/api/middleware/auth.py
Normal file
65
app/api/middleware/auth.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Auth middleware — JWT validation dependency.
|
||||
|
||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||
without requiring token re-issue.
|
||||
|
||||
Exempt routes (no JWT required):
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.schemas import UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
The JWT is used for identity and expiry only. The tier is fetched live
|
||||
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||
|
||||
Raises HTTP 401 on any invalid or expired token.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
# Live tier lookup — subscription row is the authoritative source.
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str = result.scalar_one_or_none() or "free"
|
||||
|
||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||
129
app/api/middleware/rate_limit.py
Normal file
129
app/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tier-aware rate limiting middleware.
|
||||
|
||||
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||
|
||||
Limits (requests per minute):
|
||||
- free: 20
|
||||
- pro: 60
|
||||
- power: 120
|
||||
- team: 200
|
||||
|
||||
Exempt paths bypass the limiter entirely:
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
- GET /api/v1/health
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import Request, Response
|
||||
from jose import JWTError, jwt
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
_TIER_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/billing/webhook",
|
||||
"/api/v1/health",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_user_id_from_jwt(request: Request) -> str:
|
||||
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return get_remote_address(request)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
return payload.get("sub") or get_remote_address(request)
|
||||
except JWTError:
|
||||
return get_remote_address(request)
|
||||
|
||||
|
||||
# Exported Limiter instance — available for optional route-level decoration.
|
||||
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||
|
||||
|
||||
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||
|
||||
Each authenticated user gets their own 60-second window sized by tier.
|
||||
Unauthenticated requests pass through (the auth dependency will reject them
|
||||
with 401 before the route handler runs).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
# user_id → list of request timestamps (float, seconds since epoch)
|
||||
self._window: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
if request.url.path in _EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||
tier: str = payload.get("tier", "free")
|
||||
except JWTError:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||
now = time.monotonic()
|
||||
window_start = now - 60.0
|
||||
|
||||
# Slide the window: discard timestamps older than 60 seconds.
|
||||
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||
|
||||
if len(timestamps) >= limit:
|
||||
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"detail": (
|
||||
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||
f"Retry in {retry_after}s."
|
||||
)
|
||||
}
|
||||
),
|
||||
status_code=429,
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
self._window[user_id] = timestamps
|
||||
return await call_next(request)
|
||||
139
app/api/middleware/sanitizer.py
Normal file
139
app/api/middleware/sanitizer.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Response sanitizer middleware.
|
||||
|
||||
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||
that could reveal server-side prompt IP:
|
||||
- System prompt openers ("You are a/an/the …")
|
||||
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||
- LangChain tool schema fragments (``"type": "function"``)
|
||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||
- Exact-match known prompt fingerprints
|
||||
|
||||
Binary responses (storage blobs, backup data) are never touched — the
|
||||
middleware only activates for paths under /api/v1/chat.
|
||||
|
||||
Any sanitisation event is logged as a WARNING with the request path and the
|
||||
names of the fields that were modified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||
# then compiled regexes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FINGERPRINTS: tuple[str, ...] = (
|
||||
"You are an intent classifier",
|
||||
"Respond with just the agent name",
|
||||
"Summarize these agent results",
|
||||
"Available agents:",
|
||||
"route to:",
|
||||
)
|
||||
|
||||
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||
|
||||
Returns ``(cleaned_text, was_changed)``.
|
||||
"""
|
||||
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||
for fp in _FINGERPRINTS:
|
||||
if fp in text:
|
||||
return "[REDACTED]", True
|
||||
|
||||
changed = False
|
||||
for pattern in _PATTERNS:
|
||||
new_text, n = pattern.subn("[REDACTED]", text)
|
||||
if n:
|
||||
text = new_text
|
||||
changed = True
|
||||
|
||||
return text, changed
|
||||
|
||||
|
||||
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
response: Response = await call_next(request)
|
||||
|
||||
# Only process chat endpoint responses.
|
||||
if not request.url.path.startswith("/api/v1/chat"):
|
||||
return response
|
||||
|
||||
# Read body — collect streaming chunks.
|
||||
body_bytes = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||
try:
|
||||
body = json.loads(body_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
if not isinstance(body, dict):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
# Walk top-level string fields and sanitise.
|
||||
sanitised_fields: list[str] = []
|
||||
for key, value in body.items():
|
||||
if isinstance(value, str):
|
||||
cleaned, changed = _sanitize_text(value)
|
||||
if changed:
|
||||
body[key] = cleaned
|
||||
sanitised_fields.append(key)
|
||||
|
||||
if sanitised_fields:
|
||||
logger.warning(
|
||||
"Sanitizer redacted prompt fragments",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"fields": sanitised_fields,
|
||||
},
|
||||
)
|
||||
|
||||
new_body = json.dumps(body).encode("utf-8")
|
||||
headers = dict(response.headers)
|
||||
headers["content-length"] = str(len(new_body))
|
||||
|
||||
return Response(
|
||||
content=new_body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type="application/json",
|
||||
)
|
||||
@@ -1,33 +1,36 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
|
||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
||||
migrates them to PostgreSQL.
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.models import RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
access_payload = {
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": access_exp,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
access_token = jwt.encode(
|
||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
refresh_token = str(uuid.uuid4())
|
||||
_refresh_tokens[refresh_token] = user_id
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=access_exp * 1000, # milliseconds for client
|
||||
)
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
if body.email in _users:
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
user_id = str(uuid.uuid4())
|
||||
_users[body.email] = {
|
||||
"id": user_id,
|
||||
"email": body.email,
|
||||
"password_hash": _hash_password(body.password),
|
||||
"tier": "free",
|
||||
}
|
||||
return _make_tokens(user_id, body.email, "free")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(body: _LoginRequest) -> AuthTokens:
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
user = _users.get(body.email)
|
||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
||||
if user_id is None:
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
||||
|
||||
# Rotate: delete old token, issue new one.
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing import Any
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.schemas import BackupMetadata, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
@@ -27,32 +28,11 @@ _blob_store = BlobStore()
|
||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
||||
|
||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
||||
"free": 0,
|
||||
"pro": 5,
|
||||
"power": 25,
|
||||
"team": -1, # unlimited
|
||||
}
|
||||
|
||||
|
||||
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
|
||||
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail="Backup is not available on the free tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024**3
|
||||
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||
if used + size_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
|
||||
|
||||
|
||||
@router.put("")
|
||||
@@ -69,7 +49,7 @@ async def upload_backup(
|
||||
"""
|
||||
blob = await request.body()
|
||||
reject_if_tampered(blob, x_backup_checksum)
|
||||
_check_backup_quota(current_user.id, current_user.tier, len(blob))
|
||||
_check_backup_quota(current_user.id, len(blob))
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||
|
||||
@@ -1,44 +1,25 @@
|
||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||
|
||||
Subscription records are kept in-memory until Step 12 migrates them to
|
||||
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
||||
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
||||
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||
and delegates everything else to the service singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from fastapi import APIRouter, Depends, Header, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.db import get_session
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
||||
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
||||
|
||||
_TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _stripe_configured() -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
|
||||
def _stripe() -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
|
||||
# ── Request bodies ─────────────────────────────────────────────────────
|
||||
|
||||
@@ -57,40 +38,15 @@ async def create_checkout(
|
||||
|
||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||
"""
|
||||
if body.tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
if _stripe_configured():
|
||||
price_id = _TIER_PRICE_IDS.get(body.tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {body.tier}",
|
||||
)
|
||||
s = _stripe()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=(
|
||||
"https://app.adiuva.app/billing/success"
|
||||
"?session_id={CHECKOUT_SESSION_ID}"
|
||||
),
|
||||
cancel_url="https://app.adiuva.app/billing/cancel",
|
||||
metadata={"user_id": current_user.id, "tier": body.tier},
|
||||
)
|
||||
return {"checkout_url": session.url}
|
||||
|
||||
return {"checkout_url": "https://stripe.com/stub-checkout"}
|
||||
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||
return {"checkout_url": url}
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=dict)
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
@@ -98,57 +54,17 @@ async def stripe_webhook(
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
|
||||
if not _stripe_configured():
|
||||
return {"ok": True}
|
||||
|
||||
try:
|
||||
s = _stripe()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
if user_id:
|
||||
_subscriptions[user_id] = {
|
||||
"tier": tier,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"status": "active",
|
||||
"current_period_end": None,
|
||||
}
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
|
||||
pass
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
||||
pass
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
# TODO(Step12): flag subscription as past_due, notify user
|
||||
pass
|
||||
|
||||
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=dict)
|
||||
async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = _subscriptions.get(current_user.id)
|
||||
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
@@ -159,26 +75,11 @@ async def get_subscription(
|
||||
return sub
|
||||
|
||||
|
||||
@router.delete("/subscription", response_model=dict)
|
||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
sub = _subscriptions.get(current_user.id)
|
||||
if sub is None or not sub.get("stripe_subscription_id"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if _stripe_configured():
|
||||
s = _stripe()
|
||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
||||
|
||||
_subscriptions[current_user.id] = {
|
||||
**sub,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Plugins routes: browse and install plugins from the marketplace.
|
||||
|
||||
The catalog and installation records are kept in-memory as stubs.
|
||||
Step 10 replaces these with PluginRegistry, RevenueShare, and the plugins DB table.
|
||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
|
||||
in Step 10. Step 12 will swap those services' in-memory stores for
|
||||
PostgreSQL persistence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -12,49 +13,12 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.marketplace.revenue_share import revenue_share
|
||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||
|
||||
# ── In-memory catalog (Step 10 replaces with PluginRegistry + DB) ─────
|
||||
|
||||
_plugin_catalog: list[PluginManifest] = [
|
||||
PluginManifest(
|
||||
id="plugin-github-sync",
|
||||
name="GitHub Sync",
|
||||
description="Sync tasks with GitHub Issues and pull requests.",
|
||||
version="1.0.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=0,
|
||||
),
|
||||
PluginManifest(
|
||||
id="plugin-slack-notify",
|
||||
name="Slack Notifier",
|
||||
description="Post task and checkpoint updates to Slack channels.",
|
||||
version="1.2.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "read:checkpoints"],
|
||||
category="communication",
|
||||
price_cents=499,
|
||||
),
|
||||
PluginManifest(
|
||||
id="plugin-time-tracker",
|
||||
name="Time Tracker",
|
||||
description="Track time spent on tasks with automatic reporting.",
|
||||
version="0.9.1",
|
||||
author="Third Party",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=999,
|
||||
),
|
||||
]
|
||||
|
||||
# plugin_id → set of user_ids who have installed it
|
||||
_installations: dict[str, set[str]] = {}
|
||||
|
||||
|
||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||
|
||||
@@ -67,43 +31,12 @@ def _require_plugin_tier(user: UserProfile) -> None:
|
||||
)
|
||||
|
||||
|
||||
# ── Filter + sort helpers ──────────────────────────────────────────────
|
||||
|
||||
def _apply_filters(
|
||||
plugins: list[PluginManifest],
|
||||
category: str | None,
|
||||
q: str | None,
|
||||
) -> list[PluginManifest]:
|
||||
result = plugins
|
||||
if category:
|
||||
result = [p for p in result if p.category == category]
|
||||
if q:
|
||||
q_lower = q.lower()
|
||||
result = [
|
||||
p for p in result
|
||||
if q_lower in p.name.lower() or q_lower in p.description.lower()
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def _apply_sort(
|
||||
plugins: list[PluginManifest],
|
||||
sort: str,
|
||||
) -> list[PluginManifest]:
|
||||
if sort == "installs":
|
||||
return sorted(plugins, key=lambda p: len(_installations.get(p.id, set())), reverse=True)
|
||||
if sort == "rating":
|
||||
# Placeholder until Step 10 introduces avg_rating from DB
|
||||
return sorted(plugins, key=lambda p: -p.price_cents)
|
||||
return plugins # "newest" = catalog insertion order
|
||||
|
||||
|
||||
# ── Local detail schema ────────────────────────────────────────────────
|
||||
|
||||
class _PluginDetail(BaseModel):
|
||||
plugin: PluginManifest
|
||||
install_count: int
|
||||
ratings: list[Any] # Step 10 populates from plugin_reviews table
|
||||
ratings: list[Any] # Step 12 populates from plugin_reviews table
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
@@ -118,9 +51,7 @@ async def list_plugins(
|
||||
) -> PluginListResponse:
|
||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
filtered = _apply_filters(_plugin_catalog, category, q)
|
||||
sorted_plugins = _apply_sort(filtered, sort)
|
||||
return PluginListResponse(plugins=sorted_plugins, total=len(sorted_plugins), page=page)
|
||||
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||
@@ -130,13 +61,13 @@ async def get_plugin(
|
||||
) -> _PluginDetail:
|
||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
|
||||
if plugin is None:
|
||||
entry = await registry.get_plugin(plugin_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
return _PluginDetail(
|
||||
plugin=plugin,
|
||||
install_count=len(_installations.get(plugin_id, set())),
|
||||
ratings=[], # Step 10 populates from plugin_reviews table
|
||||
plugin=entry["manifest"],
|
||||
install_count=entry["install_count"],
|
||||
ratings=[], # Step 12 populates from plugin_reviews table
|
||||
)
|
||||
|
||||
|
||||
@@ -146,20 +77,21 @@ async def install_plugin(
|
||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Install a plugin. Triggers Stripe Connect for paid plugins when configured.
|
||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||
|
||||
Requires Power tier or above.
|
||||
"""
|
||||
_require_plugin_tier(current_user)
|
||||
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
|
||||
if plugin is None:
|
||||
entry = await registry.get_plugin(plugin_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
|
||||
if plugin.price_cents > 0 and settings.STRIPE_SECRET_KEY:
|
||||
# TODO(Step10): stripe.PaymentIntent.create with destination charge (70/30 split)
|
||||
pass
|
||||
await revenue_share.record_install(
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user.id,
|
||||
amount_cents=entry["manifest"].price_cents,
|
||||
)
|
||||
|
||||
_installations.setdefault(plugin_id, set()).add(current_user.id)
|
||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||
return {"ok": True, "download_url": download_url}
|
||||
|
||||
@@ -170,5 +102,5 @@ async def uninstall_plugin(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Unregister a plugin installation."""
|
||||
_installations.get(plugin_id, set()).discard(current_user.id)
|
||||
await registry.record_uninstall(plugin_id)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
@@ -25,14 +26,6 @@ _blob_store = BlobStore()
|
||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
||||
_records: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
||||
"free": 0,
|
||||
"pro": 5,
|
||||
"power": 25,
|
||||
"team": -1, # unlimited
|
||||
}
|
||||
|
||||
|
||||
# ── Local response schemas ─────────────────────────────────────────────
|
||||
|
||||
@@ -51,18 +44,10 @@ class _RecordMeta(BaseModel):
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
|
||||
def _check_quota(user_id: str, additional_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
||||
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024**3
|
||||
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||
if used + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||
)
|
||||
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
|
||||
|
||||
|
||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
||||
@@ -83,7 +68,7 @@ async def create_record(
|
||||
) -> _CreateResponse:
|
||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
_check_quota(current_user.id, current_user.tier, len(body.blob))
|
||||
_check_quota(current_user.id, len(body.blob))
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = int(time.time() * 1000)
|
||||
@@ -159,7 +144,7 @@ async def update_record(
|
||||
|
||||
delta = len(body.blob) - record["size_bytes"]
|
||||
if delta > 0:
|
||||
_check_quota(current_user.id, current_user.tier, delta)
|
||||
_check_quota(current_user.id, delta)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.billing.tier_manager import tier_manager
|
||||
|
||||
__all__ = ["stripe_service", "tier_manager"]
|
||||
|
||||
256
app/billing/stripe_service.py
Normal file
256
app/billing/stripe_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||
|
||||
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||
configured, enabling local development without live credentials.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||
TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly",
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
class StripeService:
|
||||
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||
|
||||
# ── Internal helpers ────────────────────────────────────────────────
|
||||
|
||||
def _configured(self) -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
def _client(self) -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: str,
|
||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||
) -> str:
|
||||
"""Create a Stripe checkout session and return the URL.
|
||||
|
||||
Returns a stub URL when Stripe is not configured.
|
||||
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||
"""
|
||||
if tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
price_id = TIER_PRICE_IDS.get(tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {tier}",
|
||||
)
|
||||
|
||||
if not self._configured():
|
||||
return "https://stripe.com/stub-checkout"
|
||||
|
||||
s = self._client()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={"user_id": user_id, "tier": tier},
|
||||
)
|
||||
return session.url
|
||||
|
||||
async def handle_webhook(
|
||||
self,
|
||||
payload: bytes,
|
||||
sig_header: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Process a Stripe webhook event.
|
||||
|
||||
Verifies the signature, then dispatches on event type.
|
||||
Raises ``HTTP 400`` on signature mismatch.
|
||||
No-ops when Stripe is not configured.
|
||||
"""
|
||||
if not self._configured():
|
||||
return
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if user_id:
|
||||
await self._upsert_subscription(
|
||||
db, user_id, sub_id, tier, "active", period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
sub_id = data.get("id")
|
||||
new_status = data.get("status", "active")
|
||||
period_end_ts = data.get("current_period_end")
|
||||
period_end = (
|
||||
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||
if period_end_ts
|
||||
else None
|
||||
)
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status=new_status, current_period_end=period_end
|
||||
)
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
sub_id = data.get("id")
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, tier="free", status="canceled"
|
||||
)
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
sub_id = data.get("subscription")
|
||||
if sub_id:
|
||||
await self._update_subscription_by_stripe_id(
|
||||
db, sub_id, status="past_due"
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
async def get_subscription(
|
||||
self, user_id: str, db: AsyncSession
|
||||
) -> dict[str, Any] | None:
|
||||
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return None
|
||||
return {
|
||||
"tier": sub.tier,
|
||||
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||
"status": sub.status,
|
||||
"current_period_end": (
|
||||
int(sub.current_period_end.timestamp() * 1000)
|
||||
if sub.current_period_end
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||
|
||||
Raises ``HTTP 404`` when no active subscription exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None or not sub.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if self._configured():
|
||||
s = self._client()
|
||||
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||
|
||||
sub.tier = "free"
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
stripe_subscription_id: str | None,
|
||||
tier: str,
|
||||
sub_status: str,
|
||||
current_period_end: datetime | None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
sub = Subscription(user_id=user_id)
|
||||
db.add(sub)
|
||||
sub.stripe_subscription_id = stripe_subscription_id
|
||||
sub.tier = tier
|
||||
sub.status = sub_status
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
async def _update_subscription_by_stripe_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
stripe_subscription_id: str,
|
||||
*,
|
||||
tier: str | None = None,
|
||||
status: str | None = None,
|
||||
current_period_end: datetime | None = None,
|
||||
) -> None:
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||
)
|
||||
)
|
||||
sub = result.scalar_one_or_none()
|
||||
if sub is None:
|
||||
return
|
||||
if tier is not None:
|
||||
sub.tier = tier
|
||||
if status is not None:
|
||||
sub.status = status
|
||||
if current_period_end is not None:
|
||||
sub.current_period_end = current_period_end
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
stripe_service = StripeService()
|
||||
189
app/billing/tier_manager.py
Normal file
189
app/billing/tier_manager.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||
FEATURES: dict[str, dict[str, Any]] = {
|
||||
"free": {
|
||||
"agents": 3,
|
||||
"batch_active": 2,
|
||||
"cloud_storage_gb": 0,
|
||||
"backup_gb": 0,
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"plugin_marketplace": False,
|
||||
"sso": False,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
"batch_active": 10,
|
||||
"cloud_storage_gb": 5,
|
||||
"backup_gb": 5,
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"plugin_marketplace": False,
|
||||
"sso": False,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
"batch_active": -1, # unlimited
|
||||
"cloud_storage_gb": 25,
|
||||
"backup_gb": 25,
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"plugin_marketplace": True,
|
||||
"sso": False,
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
"batch_active": -1,
|
||||
"cloud_storage_gb": -1, # unlimited
|
||||
"backup_gb": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"plugin_marketplace": True,
|
||||
"sso": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Requests-per-minute limit per tier.
|
||||
RATE_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id`` from the DB.
|
||||
|
||||
Falls back to ``'free'`` when no subscription row exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str | None = result.scalar_one_or_none()
|
||||
if tier is None or tier not in FEATURES:
|
||||
return "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||
if not self.check_feature(tier, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
else f"Feature '{feature}' is not available on your current tier."
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
"""Return the requests-per-minute limit for ``tier``."""
|
||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||
|
||||
# ── Storage quota ────────────────────────────────────────────────────
|
||||
|
||||
def enforce_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||
|
||||
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||
"""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
if current_bytes + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def enforce_backup_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> None:
|
||||
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup is not available on the '{tier}' tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
if current_bytes + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
def check_quota(
|
||||
self,
|
||||
tier: BillingTier,
|
||||
current_bytes: int = 0,
|
||||
additional_bytes: int = 0,
|
||||
) -> bool:
|
||||
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||
if limit_gb == 0:
|
||||
return False
|
||||
if limit_gb == -1:
|
||||
return True
|
||||
limit_bytes = limit_gb * 1024 ** 3
|
||||
return current_bytes + additional_bytes <= limit_bytes
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
40
app/db.py
Normal file
40
app/db.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Database engine, session factory, and base model.
|
||||
|
||||
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||
|
||||
Usage in routes:
|
||||
from app.db import get_session
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||
result = await db.execute(select(User).where(User.email == email))
|
||||
user = result.scalar_one_or_none()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_pre_ping=True,
|
||||
echo=settings.ENV == "dev",
|
||||
)
|
||||
|
||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Shared declarative base for all ORM models."""
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""FastAPI dependency that yields an async DB session per request."""
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
11
app/main.py
11
app/main.py
@@ -3,6 +3,8 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@@ -14,7 +16,9 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: nothing to clean up for now
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
from app.db import engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -33,6 +37,11 @@ def create_app() -> FastAPI:
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||
|
||||
|
||||
7
app/marketplace/__init__.py
Normal file
7
app/marketplace/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Plugin marketplace package.
|
||||
|
||||
Three service classes introduced in Step 10:
|
||||
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||
- ``ReviewQueue`` — approval workflow + security checklist
|
||||
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||
"""
|
||||
211
app/marketplace/plugin_registry.py
Normal file
211
app/marketplace/plugin_registry.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Plugin catalog registry.
|
||||
|
||||
Maintains the authoritative list of plugins, their review status, and
|
||||
aggregate install counts. Storage is in-memory until Step 12 migrates to
|
||||
the ``plugins`` PostgreSQL table.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
from app.marketplace.plugin_registry import registry
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.schemas import PluginListResponse, PluginManifest
|
||||
|
||||
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
|
||||
|
||||
_SEED_PLUGINS: list[dict[str, Any]] = [
|
||||
{
|
||||
"manifest": PluginManifest(
|
||||
id="plugin-github-sync",
|
||||
name="GitHub Sync",
|
||||
description="Sync tasks with GitHub Issues and pull requests.",
|
||||
version="1.0.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=0,
|
||||
),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
},
|
||||
{
|
||||
"manifest": PluginManifest(
|
||||
id="plugin-slack-notify",
|
||||
name="Slack Notifier",
|
||||
description="Post task and checkpoint updates to Slack channels.",
|
||||
version="1.2.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "read:checkpoints"],
|
||||
category="communication",
|
||||
price_cents=499,
|
||||
),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
},
|
||||
{
|
||||
"manifest": PluginManifest(
|
||||
id="plugin-time-tracker",
|
||||
name="Time Tracker",
|
||||
description="Track time spent on tasks with automatic reporting.",
|
||||
version="0.9.1",
|
||||
author="Third Party",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=999,
|
||||
),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
},
|
||||
]
|
||||
|
||||
_PAGE_SIZE = 20
|
||||
|
||||
|
||||
class PluginRegistry:
|
||||
"""In-process plugin catalog.
|
||||
|
||||
All mutating methods are ``async`` to make the future DB swap transparent
|
||||
to callers.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# plugin_id → entry dict (deep-copied so each instance is independent)
|
||||
self._catalog: dict[str, dict[str, Any]] = {
|
||||
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
|
||||
}
|
||||
|
||||
# ── Queries ──────────────────────────────────────────────────────
|
||||
|
||||
async def list_plugins(
|
||||
self,
|
||||
category: str | None = None,
|
||||
query: str | None = None,
|
||||
page: int = 1,
|
||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||
) -> PluginListResponse:
|
||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
|
||||
|
||||
if category:
|
||||
entries = [e for e in entries if e["manifest"].category == category]
|
||||
|
||||
if query:
|
||||
q_lower = query.lower()
|
||||
entries = [
|
||||
e
|
||||
for e in entries
|
||||
if q_lower in e["manifest"].name.lower()
|
||||
or q_lower in e["manifest"].description.lower()
|
||||
]
|
||||
|
||||
if sort == "installs":
|
||||
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
|
||||
elif sort == "rating":
|
||||
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
|
||||
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
|
||||
|
||||
total = len(entries)
|
||||
start = (page - 1) * _PAGE_SIZE
|
||||
page_entries = entries[start : start + _PAGE_SIZE]
|
||||
|
||||
return PluginListResponse(
|
||||
plugins=[e["manifest"] for e in page_entries],
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
|
||||
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
|
||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||
entry = self._catalog.get(plugin_id)
|
||||
if entry is None:
|
||||
return None
|
||||
return {
|
||||
"manifest": entry["manifest"],
|
||||
"status": entry["status"],
|
||||
"install_count": entry["install_count"],
|
||||
"avg_rating": entry["avg_rating"],
|
||||
}
|
||||
|
||||
# ── Mutations ────────────────────────────────────────────────────
|
||||
|
||||
async def submit_plugin(
|
||||
self,
|
||||
manifest: PluginManifest,
|
||||
package_s3_key: str,
|
||||
) -> str:
|
||||
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||
|
||||
Returns the plugin_id. If a plugin with the same id already exists
|
||||
it is overwritten (re-submission after rejection).
|
||||
"""
|
||||
plugin_id = manifest.id or str(uuid.uuid4())
|
||||
self._catalog[plugin_id] = {
|
||||
"manifest": manifest,
|
||||
"status": "pending_review",
|
||||
"s3_package_key": package_s3_key,
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
}
|
||||
return plugin_id
|
||||
|
||||
async def approve_plugin(self, plugin_id: str) -> None:
|
||||
"""Set *plugin_id* status to ``'approved'``.
|
||||
|
||||
Raises ``KeyError`` if the plugin is not found.
|
||||
"""
|
||||
if plugin_id not in self._catalog:
|
||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||
self._catalog[plugin_id]["status"] = "approved"
|
||||
self._catalog[plugin_id]["rejection_reason"] = None
|
||||
|
||||
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
|
||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||
|
||||
Raises ``KeyError`` if the plugin is not found.
|
||||
"""
|
||||
if plugin_id not in self._catalog:
|
||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||
self._catalog[plugin_id]["status"] = "rejected"
|
||||
self._catalog[plugin_id]["rejection_reason"] = reason
|
||||
|
||||
async def record_install(self, plugin_id: str) -> None:
|
||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||
if plugin_id in self._catalog:
|
||||
self._catalog[plugin_id]["install_count"] += 1
|
||||
|
||||
async def record_uninstall(self, plugin_id: str) -> None:
|
||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||
if plugin_id in self._catalog:
|
||||
current = self._catalog[plugin_id]["install_count"]
|
||||
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
|
||||
|
||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||
|
||||
def _get_pending_entries(self) -> list[dict[str, Any]]:
|
||||
"""Return all entries with status='pending_review' (synchronous helper)."""
|
||||
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = PluginRegistry()
|
||||
127
app/marketplace/plugin_review.py
Normal file
127
app/marketplace/plugin_review.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Plugin review workflow.
|
||||
|
||||
Manages the approval queue for newly submitted plugins and enforces a
|
||||
security checklist before any plugin is made visible in the marketplace.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
from app.marketplace.plugin_review import review_queue
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.schemas import PluginManifest
|
||||
|
||||
# ── Security policy ───────────────────────────────────────────────────
|
||||
|
||||
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||
{
|
||||
"read:tasks",
|
||||
"write:tasks",
|
||||
"read:projects",
|
||||
"write:projects",
|
||||
"read:notes",
|
||||
"write:notes",
|
||||
"read:checkpoints",
|
||||
"write:checkpoints",
|
||||
"read:calendar",
|
||||
"write:calendar",
|
||||
}
|
||||
)
|
||||
|
||||
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||
|
||||
|
||||
def validate_manifest(manifest: PluginManifest) -> None:
|
||||
"""Enforce the plugin security checklist.
|
||||
|
||||
Raises:
|
||||
``ValueError`` on the first violation found. Callers should catch
|
||||
this and return HTTP 422 / reject the submission.
|
||||
|
||||
Checks:
|
||||
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||
3. No manifest field contains raw binary data
|
||||
"""
|
||||
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||
raise ValueError(
|
||||
f"Invalid plugin id format: '{manifest.id}'. "
|
||||
"Only lowercase letters, digits, and hyphens are allowed."
|
||||
)
|
||||
|
||||
for perm in manifest.permissions:
|
||||
if perm not in ALLOWED_PERMISSIONS:
|
||||
raise ValueError(
|
||||
f"Unknown permission: '{perm}'. "
|
||||
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||
)
|
||||
|
||||
for field_name, value in manifest.model_dump().items():
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
raise ValueError(
|
||||
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||
)
|
||||
|
||||
|
||||
class ReviewQueue:
|
||||
"""Approval queue for pending plugin submissions.
|
||||
|
||||
Delegates status changes to the shared ``PluginRegistry`` singleton so
|
||||
there is a single source of truth for plugin state.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Completed reviews — Step 12 stores in plugin_reviews table
|
||||
self._reviews: list[dict[str, Any]] = []
|
||||
|
||||
async def get_pending(self) -> list[dict[str, Any]]:
|
||||
"""Return all plugins currently awaiting review.
|
||||
|
||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||
"""
|
||||
entries = registry._get_pending_entries()
|
||||
return [
|
||||
{
|
||||
"plugin_id": e["manifest"].id,
|
||||
"manifest": e["manifest"],
|
||||
"submitted_at": e["submitted_at"],
|
||||
}
|
||||
for e in entries
|
||||
]
|
||||
|
||||
async def submit_review(
|
||||
self,
|
||||
plugin_id: str,
|
||||
reviewer_id: str,
|
||||
decision: Literal["approved", "rejected"],
|
||||
notes: str = "",
|
||||
) -> None:
|
||||
"""Record a review decision and update the plugin's status.
|
||||
|
||||
Raises:
|
||||
``KeyError`` if *plugin_id* is not found in the registry.
|
||||
"""
|
||||
if decision == "approved":
|
||||
await registry.approve_plugin(plugin_id)
|
||||
else:
|
||||
await registry.reject_plugin(plugin_id, reason=notes)
|
||||
|
||||
self._reviews.append(
|
||||
{
|
||||
"plugin_id": plugin_id,
|
||||
"reviewer_id": reviewer_id,
|
||||
"decision": decision,
|
||||
"notes": notes,
|
||||
"reviewed_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
review_queue = ReviewQueue()
|
||||
205
app/marketplace/revenue_share.py
Normal file
205
app/marketplace/revenue_share.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Revenue share tracking and Stripe Connect payouts.
|
||||
|
||||
Records every plugin installation as a revenue event and facilitates
|
||||
70 % / 30 % payouts to developers via Stripe Connect. Storage is
|
||||
in-memory until Step 12 migrates to the ``revenue_events`` table.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
from app.marketplace.revenue_share import revenue_share
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.marketplace.plugin_registry import registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Revenue split constants ───────────────────────────────────────────
|
||||
|
||||
DEVELOPER_SHARE: float = 0.70
|
||||
PLATFORM_SHARE: float = 0.30
|
||||
|
||||
|
||||
class RevenueShare:
|
||||
"""Records installation revenue events and coordinates developer payouts.
|
||||
|
||||
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||
is not configured, consistent with the rest of the billing layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Step 12 replaces with revenue_events DB table
|
||||
self._events: list[dict[str, Any]] = []
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _stripe_configured() -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
@staticmethod
|
||||
def _stripe() -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
# ── Core operations ──────────────────────────────────────────────
|
||||
|
||||
async def record_install(
|
||||
self,
|
||||
plugin_id: str,
|
||||
user_id: str,
|
||||
amount_cents: int,
|
||||
) -> None:
|
||||
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||
|
||||
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||
the event is still recorded for analytics.
|
||||
|
||||
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||
destination charge. If Stripe is not configured or the charge fails
|
||||
the installation still succeeds (the event is recorded and the install
|
||||
count is incremented) — a warning is logged for monitoring.
|
||||
"""
|
||||
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||
stripe_transfer_id: str | None = None
|
||||
|
||||
if amount_cents > 0 and self._stripe_configured():
|
||||
plugin_entry = registry._catalog.get(plugin_id)
|
||||
developer_stripe_account: str | None = None
|
||||
if plugin_entry:
|
||||
# Step 12: look up developer's Stripe account from DB
|
||||
# For now, the author field is used as a placeholder key.
|
||||
developer_stripe_account = None # no real account yet
|
||||
|
||||
if developer_stripe_account:
|
||||
try:
|
||||
s = self._stripe()
|
||||
transfer = s.Transfer.create(
|
||||
amount=developer_share_cents,
|
||||
currency="eur",
|
||||
destination=developer_stripe_account,
|
||||
description=f"Revenue share for plugin {plugin_id}",
|
||||
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||
)
|
||||
stripe_transfer_id = transfer["id"]
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Stripe Connect transfer failed for plugin %s: %s",
|
||||
plugin_id,
|
||||
exc,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"No Stripe account on file for plugin %s developer; "
|
||||
"skipping transfer.",
|
||||
plugin_id,
|
||||
)
|
||||
|
||||
self._events.append(
|
||||
{
|
||||
"plugin_id": plugin_id,
|
||||
"user_id": user_id,
|
||||
"amount_cents": amount_cents,
|
||||
"developer_share_cents": developer_share_cents,
|
||||
"stripe_transfer_id": stripe_transfer_id,
|
||||
"paid_at": None,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
await registry.record_install(plugin_id)
|
||||
|
||||
async def get_earnings(
|
||||
self,
|
||||
developer_id: str,
|
||||
period: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return aggregated earnings for *developer_id*.
|
||||
|
||||
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||
|
||||
Returns::
|
||||
|
||||
{
|
||||
"developer_id": str,
|
||||
"period": str | None,
|
||||
"total_installs": int,
|
||||
"total_revenue_cents": int,
|
||||
"developer_share_cents": int,
|
||||
}
|
||||
"""
|
||||
# Find plugin ids belonging to this developer
|
||||
developer_plugin_ids: set[str] = {
|
||||
pid
|
||||
for pid, entry in registry._catalog.items()
|
||||
if entry["manifest"].author == developer_id
|
||||
}
|
||||
|
||||
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
|
||||
|
||||
if period:
|
||||
# Filter by YYYY-MM prefix of the created_at timestamp
|
||||
events = [
|
||||
e
|
||||
for e in events
|
||||
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
||||
]
|
||||
|
||||
return {
|
||||
"developer_id": developer_id,
|
||||
"period": period,
|
||||
"total_installs": len(events),
|
||||
"total_revenue_cents": sum(e["amount_cents"] for e in events),
|
||||
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
|
||||
}
|
||||
|
||||
async def payout_developer(self, plugin_id: str, period: str) -> None:
|
||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||
|
||||
Marks processed events with ``paid_at`` timestamp.
|
||||
Stubs gracefully when Stripe is not configured.
|
||||
"""
|
||||
unpaid = [
|
||||
e
|
||||
for e in self._events
|
||||
if e["plugin_id"] == plugin_id
|
||||
and e["paid_at"] is None
|
||||
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
||||
]
|
||||
|
||||
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
|
||||
if total_dev_share <= 0 or not unpaid:
|
||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||
return
|
||||
|
||||
if self._stripe_configured():
|
||||
plugin_entry = registry._catalog.get(plugin_id)
|
||||
developer_stripe_account: str | None = None # Step 12: fetch from DB
|
||||
if plugin_entry and developer_stripe_account:
|
||||
try:
|
||||
s = self._stripe()
|
||||
s.Transfer.create(
|
||||
amount=total_dev_share,
|
||||
currency="eur",
|
||||
destination=developer_stripe_account,
|
||||
description=f"Payout for plugin {plugin_id} period {period}",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||
return
|
||||
|
||||
paid_ts = int(time.time())
|
||||
for event in unpaid:
|
||||
event["paid_at"] = paid_ts
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
revenue_share = RevenueShare()
|
||||
269
app/models.py
Normal file
269
app/models.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""SQLAlchemy ORM models for all persistent tables.
|
||||
|
||||
Only auth, billing, storage metadata, and marketplace data live here.
|
||||
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||
|
||||
Table inventory:
|
||||
users — account credentials + tier
|
||||
refresh_tokens — hashed refresh token store
|
||||
subscriptions — Stripe subscription records
|
||||
storage_records — S3 blob metadata (no plaintext)
|
||||
backup_metadata — encrypted backup manifests
|
||||
plugins — marketplace plugin catalog
|
||||
plugin_installations — per-user install records
|
||||
plugin_reviews — admin review decisions
|
||||
revenue_events — Stripe Connect 70/30 split ledger
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
DateTime,
|
||||
Enum,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# ── Enum types ────────────────────────────────────────────────────────────
|
||||
|
||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||
|
||||
|
||||
# ── Models ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
subscription: Mapped[Subscription | None] = relationship(
|
||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="subscription")
|
||||
|
||||
|
||||
class StorageRecord(Base):
|
||||
__tablename__ = "storage_records"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class BackupMetadata(Base):
|
||||
__tablename__ = "backup_metadata"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
class Plugin(Base):
|
||||
__tablename__ = "plugins"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||
# nullable until developer account system is built
|
||||
author_id: Mapped[str | None] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
submitted_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
reviews: Mapped[list[PluginReview]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||
back_populates="plugin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class PluginInstallation(Base):
|
||||
__tablename__ = "plugin_installations"
|
||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
installed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||
|
||||
|
||||
class PluginReview(Base):
|
||||
__tablename__ = "plugin_reviews"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
reviewer_id: Mapped[str | None] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
reviewed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||
|
||||
|
||||
class RevenueEvent(Base):
|
||||
__tablename__ = "revenue_events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||
304
tests/test_middleware.py
Normal file
304
tests/test_middleware.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
|
||||
|
||||
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
|
||||
Rate limit: use unique user UUIDs per test so windows are independent;
|
||||
the free-tier threshold (20 req/min) is exercised directly.
|
||||
Sanitizer: the orchestrator is mocked to inject controlled prompt
|
||||
fragments, and the chat endpoint response body is inspected.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.main import app
|
||||
from app.schemas import ChatResponse
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CHAT_BODY = {
|
||||
"message": "hello",
|
||||
"context": {
|
||||
"user_profile": {},
|
||||
"relevant_documents": [],
|
||||
"recent_tasks": [],
|
||||
"conversation_history": [],
|
||||
},
|
||||
"execution_mode": "direct",
|
||||
}
|
||||
|
||||
|
||||
def _make_jwt(
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
email: str = "test@example.com",
|
||||
tier: str = "free",
|
||||
exp_offset: int = 3600,
|
||||
secret: str | None = None,
|
||||
include_sub: bool = True,
|
||||
) -> str:
|
||||
"""Mint a test JWT signed with the configured (or custom) secret."""
|
||||
uid = user_id or str(uuid.uuid4())
|
||||
now = int(time.time())
|
||||
payload: dict = {
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": now + exp_offset,
|
||||
"iat": now,
|
||||
}
|
||||
if include_sub:
|
||||
payload["sub"] = uid
|
||||
key = secret or settings.JWT_SECRET
|
||||
return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def _auth_header(token: str) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAuthMiddleware:
|
||||
"""Tests exercised via GET /api/v1/auth/me."""
|
||||
|
||||
def test_valid_token_returns_profile(self) -> None:
|
||||
uid = str(uuid.uuid4())
|
||||
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == uid
|
||||
assert data["email"] == "alice@example.com"
|
||||
assert data["tier"] == "pro"
|
||||
|
||||
def test_missing_token_returns_401(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_expired_token_returns_401(self) -> None:
|
||||
token = _make_jwt(exp_offset=-1) # already expired
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_wrong_signature_returns_401(self) -> None:
|
||||
token = _make_jwt(secret="totally-wrong-secret")
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_missing_sub_claim_returns_401(self) -> None:
|
||||
token = _make_jwt(include_sub=False)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_malformed_token_returns_401(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get(
|
||||
"/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"}
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Rate limiter middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Each test uses a fresh unique user_id so windows never collide."""
|
||||
|
||||
def _unique_token(self, tier: str = "free") -> str:
|
||||
return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)
|
||||
|
||||
def test_free_tier_allows_up_to_20_requests(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_free_tier_blocks_21st_request(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_429_includes_retry_after_header(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
assert "retry-after" in resp.headers
|
||||
retry_after = int(resp.headers["retry-after"])
|
||||
assert retry_after >= 1
|
||||
|
||||
def test_429_response_has_detail_field(self) -> None:
|
||||
token = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
assert "detail" in resp.json()
|
||||
|
||||
def test_pro_tier_allows_60_requests(self) -> None:
|
||||
token = self._unique_token("pro")
|
||||
with TestClient(app) as client:
|
||||
# Sample: first 60 succeed, 61st is blocked.
|
||||
for _ in range(60):
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 429
|
||||
|
||||
def test_independent_users_have_separate_windows(self) -> None:
|
||||
token_a = self._unique_token("free")
|
||||
token_b = self._unique_token("free")
|
||||
with TestClient(app) as client:
|
||||
# Exhaust user A's quota.
|
||||
for _ in range(20):
|
||||
client.get("/api/v1/auth/me", headers=_auth_header(token_a))
|
||||
assert (
|
||||
client.get(
|
||||
"/api/v1/auth/me", headers=_auth_header(token_a)
|
||||
).status_code
|
||||
== 429
|
||||
)
|
||||
# User B's quota is untouched.
|
||||
resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b))
|
||||
assert resp_b.status_code == 200
|
||||
|
||||
def test_exempt_path_register_never_rate_limited(self) -> None:
|
||||
"""POST /auth/register is exempt — 25 calls should never return 429."""
|
||||
with TestClient(app) as client:
|
||||
for i in range(25):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
|
||||
)
|
||||
# 201 on first, 409 on email collision — but never 429.
|
||||
assert resp.status_code != 429
|
||||
|
||||
def test_exempt_path_login_never_rate_limited(self) -> None:
|
||||
"""POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
|
||||
with TestClient(app) as client:
|
||||
for _ in range(25):
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "nosuchuser@example.com", "password": "wrong"},
|
||||
)
|
||||
assert resp.status_code != 429
|
||||
|
||||
def test_exempt_path_health_never_rate_limited(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
for _ in range(25):
|
||||
resp = client.get("/api/v1/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sanitizer middleware
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSanitizerMiddleware:
|
||||
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
||||
|
||||
_CHAT_PATH = "/api/v1/chat"
|
||||
|
||||
def _token(self) -> str:
|
||||
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||
|
||||
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||
mock_response = ChatResponse(response=response_text, actions=[])
|
||||
with patch(
|
||||
"app.api.routes.chat.orchestrate",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
resp = client.post(
|
||||
self._CHAT_PATH,
|
||||
json=_CHAT_BODY,
|
||||
headers=_auth_header(self._token()),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
return resp.json()
|
||||
|
||||
def test_clean_response_passes_through_unchanged(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(client, "Sure, I created the task for you.")
|
||||
assert data["response"] == "Sure, I created the task for you."
|
||||
|
||||
def test_strips_system_prompt_opener(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "You are an intent classifier. Route to task_agent."
|
||||
)
|
||||
assert "You are" not in data["response"]
|
||||
assert "[REDACTED]" in data["response"]
|
||||
|
||||
def test_strips_known_fingerprint(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "Respond with just the agent name and nothing else."
|
||||
)
|
||||
assert data["response"] == "[REDACTED]"
|
||||
|
||||
def test_strips_tool_schema_fragment(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, 'Here is the schema: {"type": "function", "name": "foo"}'
|
||||
)
|
||||
assert '"type": "function"' not in data["response"]
|
||||
|
||||
def test_strips_reasoning_tag(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "<thinking>I should route this to calendar_agent</thinking>Done."
|
||||
)
|
||||
assert "<thinking>" not in data["response"]
|
||||
assert "[REDACTED]" in data["response"]
|
||||
|
||||
def test_strips_available_agents_fragment(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(
|
||||
client, "Available agents: task_agent, calendar_agent"
|
||||
)
|
||||
assert "[REDACTED]" in data["response"]
|
||||
|
||||
def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
|
||||
"""GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
|
||||
token = self._token()
|
||||
with TestClient(app) as client:
|
||||
resp = client.get(
|
||||
"/api/v1/plans/playbook",
|
||||
headers=_auth_header(token),
|
||||
)
|
||||
# The sanitizer should not interfere — just check it returns something
|
||||
# (200 or whatever the route returns; we only care it's not broken).
|
||||
assert resp.status_code in (200, 401, 403, 404)
|
||||
|
||||
def test_sanitizer_preserves_empty_response(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
data = self._post_chat(client, "")
|
||||
assert data["response"] == ""
|
||||
387
tests/test_plugins.py
Normal file
387
tests/test_plugins.py
Normal file
@@ -0,0 +1,387 @@
|
||||
"""Tests for Step 10: Plugin Marketplace.
|
||||
|
||||
Covers:
|
||||
- PluginRegistry: catalog management, filtering, sorting, install counts
|
||||
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||
- RevenueShare: install event recording, earnings aggregation
|
||||
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
from unittest.mock import patch
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.main import app
|
||||
from app.marketplace.plugin_registry import PluginRegistry
|
||||
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||
from app.marketplace.revenue_share import RevenueShare
|
||||
from app.schemas import PluginManifest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_jwt(tier: str = "power", user_id: str | None = None) -> str:
|
||||
uid = user_id or str(uuid.uuid4())
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": uid,
|
||||
"email": f"{uid[:8]}@example.com",
|
||||
"tier": tier,
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
}
|
||||
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def _auth(tier: str = "power") -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {_make_jwt(tier)}"}
|
||||
|
||||
|
||||
def _fresh_manifest(
|
||||
plugin_id: str | None = None,
|
||||
category: str = "productivity",
|
||||
price_cents: int = 0,
|
||||
permissions: list[str] | None = None,
|
||||
) -> PluginManifest:
|
||||
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
|
||||
return PluginManifest(
|
||||
id=pid,
|
||||
name=f"Plugin {pid}",
|
||||
description=f"Description for {pid}",
|
||||
version="1.0.0",
|
||||
author="test-author",
|
||||
permissions=permissions or ["read:tasks"],
|
||||
category=category,
|
||||
price_cents=price_cents,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginRegistry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginRegistry:
|
||||
"""Each test uses a fresh PluginRegistry instance to avoid catalog pollution."""
|
||||
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None:
|
||||
result = await reg.list_plugins()
|
||||
assert result.total == 3
|
||||
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_approved_only(self, reg: PluginRegistry) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "plugins/key.zip")
|
||||
result = await reg.list_plugins()
|
||||
ids = [p.id for p in result.plugins]
|
||||
assert manifest.id not in ids # still pending
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_filter_by_category(self, reg: PluginRegistry) -> None:
|
||||
result = await reg.list_plugins(category="communication")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-slack-notify"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_filter_by_query(self, reg: PluginRegistry) -> None:
|
||||
result = await reg.list_plugins(query="time")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-time-tracker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_install("plugin-slack-notify")
|
||||
await reg.record_install("plugin-slack-notify")
|
||||
result = await reg.list_plugins(sort="installs")
|
||||
assert result.plugins[0].id == "plugin-slack-notify"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_found(self, reg: PluginRegistry) -> None:
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["manifest"].id == "plugin-github-sync"
|
||||
assert "install_count" in entry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None:
|
||||
entry = await reg.get_plugin("no-such-plugin")
|
||||
assert entry is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_sets_pending(self, reg: PluginRegistry) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
plugin_id = await reg.submit_plugin(manifest, "key.zip")
|
||||
assert plugin_id == manifest.id
|
||||
assert reg._catalog[plugin_id]["status"] == "pending_review"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_makes_visible(self, reg: PluginRegistry) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await reg.approve_plugin(manifest.id)
|
||||
result = await reg.list_plugins()
|
||||
assert manifest.id in [p.id for p in result.plugins]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_stores_reason(self, reg: PluginRegistry) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await reg.reject_plugin(manifest.id, reason="Unsafe permissions")
|
||||
assert reg._catalog[manifest.id]["status"] == "rejected"
|
||||
assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions"
|
||||
result = await reg.list_plugins()
|
||||
assert manifest.id not in [p.id for p in result.plugins]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await reg.approve_plugin("ghost-plugin")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_increments_count(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_install("plugin-github-sync")
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_install("plugin-github-sync")
|
||||
await reg.record_install("plugin-github-sync")
|
||||
await reg.record_uninstall("plugin-github-sync")
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_uninstall("plugin-github-sync") # already 0
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReviewQueue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestReviewQueue:
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def queue(self, reg: PluginRegistry) -> ReviewQueue:
|
||||
# Patch the 'registry' name as bound inside plugin_review.py
|
||||
with patch("app.marketplace.plugin_review.registry", reg):
|
||||
yield ReviewQueue()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_returns_submitted_plugins(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
pending = await queue.get_pending()
|
||||
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_review_approved(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good")
|
||||
assert reg._catalog[manifest.id]["status"] == "approved"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_review_rejected(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions")
|
||||
assert reg._catalog[manifest.id]["status"] == "rejected"
|
||||
|
||||
def test_validate_manifest_ok(self) -> None:
|
||||
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||
validate_manifest(manifest) # should not raise
|
||||
|
||||
def test_validate_manifest_unknown_permission(self) -> None:
|
||||
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
|
||||
with pytest.raises(ValueError, match="Unknown permission"):
|
||||
validate_manifest(manifest)
|
||||
|
||||
def test_validate_manifest_invalid_id_format(self) -> None:
|
||||
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
|
||||
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||
validate_manifest(manifest)
|
||||
|
||||
def test_validate_manifest_id_with_uppercase(self) -> None:
|
||||
manifest = _fresh_manifest(plugin_id="UpperCase")
|
||||
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||
validate_manifest(manifest)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RevenueShare
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRevenueShare:
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def rs(self, reg: PluginRegistry) -> RevenueShare:
|
||||
# Patch the 'registry' name as bound inside revenue_share.py
|
||||
with patch("app.marketplace.revenue_share.registry", reg):
|
||||
yield RevenueShare()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_free_plugin(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
) -> None:
|
||||
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
||||
assert len(rs._events) == 1
|
||||
assert rs._events[0]["developer_share_cents"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_paid_plugin_no_stripe(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
) -> None:
|
||||
# No STRIPE_SECRET_KEY configured in test env — should not crash
|
||||
await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499)
|
||||
assert len(rs._events) == 1
|
||||
assert rs._events[0]["amount_cents"] == 499
|
||||
assert rs._events[0]["developer_share_cents"] == int(499 * 0.70)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_increments_registry_count(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
) -> None:
|
||||
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_earnings_empty(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
) -> None:
|
||||
result = await rs.get_earnings("unknown-dev")
|
||||
assert result["total_installs"] == 0
|
||||
assert result["total_revenue_cents"] == 0
|
||||
assert result["developer_share_cents"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_earnings_aggregates(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
) -> None:
|
||||
# "Adiuva" is the author of the seeded plugins
|
||||
await rs.record_install("plugin-slack-notify", "u1", amount_cents=499)
|
||||
await rs.record_install("plugin-slack-notify", "u2", amount_cents=499)
|
||||
result = await rs.get_earnings("Adiuva")
|
||||
assert result["total_installs"] == 2
|
||||
assert result["total_revenue_cents"] == 998
|
||||
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route integration tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginRoutes:
|
||||
def test_list_plugins_requires_power_tier(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("free"))
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_list_plugins_pro_tier_blocked(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("pro"))
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_list_plugins_power_tier_ok(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("power"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "plugins" in data
|
||||
assert data["total"] >= 3
|
||||
|
||||
def test_list_plugins_team_tier_ok(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("team"))
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_get_plugin_found(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth())
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||
assert "install_count" in data
|
||||
|
||||
def test_get_plugin_not_found(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth())
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_install_plugin_free(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=_auth(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert "download_url" in data
|
||||
|
||||
def test_install_plugin_not_found(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/ghost/install",
|
||||
json={"plugin_id": "ghost"},
|
||||
headers=_auth(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_uninstall_plugin_ok(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.delete(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
headers=_auth(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["ok"] is True
|
||||
|
||||
def test_install_requires_power_tier(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=_auth("free"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
Reference in New Issue
Block a user