Compare commits
4 Commits
4c4df7335a
...
5d485b3665
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d485b3665 | |||
| 9787befd4a | |||
| 8f7bc25611 | |||
| 3e07fff958 |
@@ -331,14 +331,14 @@ adiuva-api/
|
|||||||
### Step 9 — Middleware
|
### Step 9 — Middleware
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
#### 9a — Auth middleware
|
||||||
- [ ] `app/api/middleware/auth.py`:
|
- [x] `app/api/middleware/auth.py`:
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
||||||
- Raises `401` on invalid/expired token
|
- Raises `401` on invalid/expired token
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
#### 9b — Rate limiter
|
||||||
- [ ] `app/api/middleware/rate_limit.py`:
|
- [x] `app/api/middleware/rate_limit.py`:
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
||||||
- Tier-based limits:
|
- Tier-based limits:
|
||||||
- Free: 20 req/min
|
- Free: 20 req/min
|
||||||
@@ -348,7 +348,7 @@ adiuva-api/
|
|||||||
- Custom 429 response with `Retry-After` header
|
- Custom 429 response with `Retry-After` header
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
#### 9c — Sanitizer
|
||||||
- [ ] `app/api/middleware/sanitizer.py`:
|
- [x] `app/api/middleware/sanitizer.py`:
|
||||||
- Response middleware that scans response bodies
|
- Response middleware that scans response bodies
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
- Pattern-based detection + exact match against known prompt fingerprints
|
||||||
@@ -356,33 +356,33 @@ adiuva-api/
|
|||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace
|
### Step 10 — Plugin Marketplace ✅
|
||||||
- [ ] `app/marketplace/plugin_registry.py`:
|
- [x] `app/marketplace/plugin_registry.py`:
|
||||||
- `PluginRegistry`:
|
- `PluginRegistry`:
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
- `async reject_plugin(plugin_id, reason: str) -> None`
|
||||||
- [ ] `app/marketplace/plugin_review.py`:
|
- [x] `app/marketplace/plugin_review.py`:
|
||||||
- `ReviewQueue`:
|
- `ReviewQueue`:
|
||||||
- `async get_pending() -> list[dict]`
|
- `async get_pending() -> list[dict]`
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
||||||
- [ ] `app/marketplace/revenue_share.py`:
|
- [x] `app/marketplace/revenue_share.py`:
|
||||||
- `RevenueShare`:
|
- `RevenueShare`:
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
- `async get_earnings(developer_id, period) -> dict`
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management
|
### Step 11 — Billing & Tier management ✅
|
||||||
- [ ] `app/billing/stripe_service.py`:
|
- [x] `app/billing/stripe_service.py`:
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
- `create_checkout_session(user_id, tier) -> str`
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
||||||
- `get_subscription(user_id) -> dict | None`
|
- `get_subscription(user_id) -> dict | None`
|
||||||
- `cancel_subscription(user_id) -> None`
|
- `cancel_subscription(user_id) -> None`
|
||||||
- [ ] `app/billing/tier_manager.py`:
|
- [x] `app/billing/tier_manager.py`:
|
||||||
- `TierManager`:
|
- `TierManager`:
|
||||||
- Feature matrix:
|
- Feature matrix:
|
||||||
```python
|
```python
|
||||||
@@ -433,6 +433,9 @@ adiuva-api/
|
|||||||
- `check_feature(user_id, feature) -> bool`
|
- `check_feature(user_id, feature) -> bool`
|
||||||
- `get_rate_limit(tier) -> int`
|
- `get_rate_limit(tier) -> int`
|
||||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
||||||
|
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
||||||
|
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
||||||
|
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
### Step 12 — Database (auth/billing/marketplace only)
|
||||||
|
|||||||
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Alembic configuration file.
|
||||||
|
# The async app uses postgresql+asyncpg:// at runtime.
|
||||||
|
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = .
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Alembic migration environment — async-compatible.
|
||||||
|
|
||||||
|
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||||
|
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||||
|
env var by replacing the driver prefix.
|
||||||
|
|
||||||
|
Run migrations with:
|
||||||
|
alembic upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Set up Python logging from alembic.ini.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||||
|
from app.models import Base # noqa: E402
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_url(async_url: str) -> str:
|
||||||
|
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||||
|
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
db_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not db_url:
|
||||||
|
# Fall back to settings if env var not set directly.
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
db_url = settings.DATABASE_URL
|
||||||
|
return _sync_url(db_url)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Emit SQL without a live DB connection."""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online_async() -> None:
|
||||||
|
"""Run migrations against a live DB using the async engine."""
|
||||||
|
async_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not async_url:
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
async_url = settings.DATABASE_URL
|
||||||
|
|
||||||
|
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
asyncio.run(run_migrations_online_async())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
202
alembic/versions/001_initial_schema.py
Normal file
202
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
|
Revision ID: 001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "001"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────
|
||||||
|
billing_tier = postgresql.ENUM(
|
||||||
|
"free", "pro", "power", "team", name="billing_tier", create_type=False
|
||||||
|
)
|
||||||
|
plugin_status = postgresql.ENUM(
|
||||||
|
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
|
||||||
|
)
|
||||||
|
review_decision = postgresql.ENUM(
|
||||||
|
"approved", "rejected", name="review_decision", create_type=False
|
||||||
|
)
|
||||||
|
for enum in (billing_tier, plugin_status, review_decision):
|
||||||
|
enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("email", sa.String(255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
|
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||||
|
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("email"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_users_email", "users", ["email"])
|
||||||
|
|
||||||
|
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"refresh_tokens",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("token_hash"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||||
|
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||||
|
|
||||||
|
# ── subscriptions ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"subscriptions",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
|
||||||
|
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||||
|
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("user_id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
|
op.drop_table("subscriptions")
|
||||||
|
op.drop_table("refresh_tokens")
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
@@ -1,46 +1,14 @@
|
|||||||
"""Shared FastAPI dependencies.
|
"""Shared FastAPI dependencies.
|
||||||
|
|
||||||
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
|
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||||
Step 9 will layer rate-limiting and sanitization middleware on top of this.
|
(the canonical location per Step 9). This module re-exports them so that all
|
||||||
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
|
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||||
|
to work without modification.
|
||||||
|
|
||||||
|
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||||
|
instead of reading it from the JWT payload.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import BillingTier, UserProfile
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
|
||||||
token: str = Depends(oauth2_scheme),
|
|
||||||
) -> UserProfile:
|
|
||||||
"""Validate a Bearer JWT and return the authenticated user.
|
|
||||||
|
|
||||||
Raises ``HTTP 401`` on any invalid or expired token.
|
|
||||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
|
||||||
adds a live DB lookup.
|
|
||||||
"""
|
|
||||||
credentials_exc = HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
email: str | None = payload.get("email")
|
|
||||||
tier: str = payload.get("tier", "free")
|
|
||||||
if not user_id or not email:
|
|
||||||
raise credentials_exc
|
|
||||||
except JWTError:
|
|
||||||
raise credentials_exc
|
|
||||||
|
|
||||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
"""API middleware package.
|
||||||
|
|
||||||
|
Exports the three middleware components introduced in Step 9:
|
||||||
|
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||||
|
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||||
|
- Sanitizer: ``SanitizerMiddleware``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_current_user",
|
||||||
|
"oauth2_scheme",
|
||||||
|
"TierRateLimitMiddleware",
|
||||||
|
"limiter",
|
||||||
|
"SanitizerMiddleware",
|
||||||
|
]
|
||||||
|
|||||||
65
app/api/middleware/auth.py
Normal file
65
app/api/middleware/auth.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""Auth middleware — JWT validation dependency.
|
||||||
|
|
||||||
|
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||||
|
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||||
|
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||||
|
without requiring token re-issue.
|
||||||
|
|
||||||
|
Exempt routes (no JWT required):
|
||||||
|
- POST /api/v1/auth/register
|
||||||
|
- POST /api/v1/auth/login
|
||||||
|
- POST /api/v1/billing/webhook
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry only. The tier is fetched live
|
||||||
|
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||||
|
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
|
||||||
|
Raises HTTP 401 on any invalid or expired token.
|
||||||
|
"""
|
||||||
|
credentials_exc = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
raise credentials_exc
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exc
|
||||||
|
|
||||||
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
|
|
||||||
|
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||||
129
app/api/middleware/rate_limit.py
Normal file
129
app/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Tier-aware rate limiting middleware.
|
||||||
|
|
||||||
|
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||||
|
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||||
|
|
||||||
|
Limits (requests per minute):
|
||||||
|
- free: 20
|
||||||
|
- pro: 60
|
||||||
|
- power: 120
|
||||||
|
- team: 200
|
||||||
|
|
||||||
|
Exempt paths bypass the limiter entirely:
|
||||||
|
- POST /api/v1/auth/register
|
||||||
|
- POST /api/v1/auth/login
|
||||||
|
- POST /api/v1/billing/webhook
|
||||||
|
- GET /api/v1/health
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
_TIER_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
"/api/v1/health",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id_from_jwt(request: Request) -> str:
|
||||||
|
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not token:
|
||||||
|
return get_remote_address(request)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
return payload.get("sub") or get_remote_address(request)
|
||||||
|
except JWTError:
|
||||||
|
return get_remote_address(request)
|
||||||
|
|
||||||
|
|
||||||
|
# Exported Limiter instance — available for optional route-level decoration.
|
||||||
|
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||||
|
|
||||||
|
|
||||||
|
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||||
|
|
||||||
|
Each authenticated user gets their own 60-second window sized by tier.
|
||||||
|
Unauthenticated requests pass through (the auth dependency will reject them
|
||||||
|
with 401 before the route handler runs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
# user_id → list of request timestamps (float, seconds since epoch)
|
||||||
|
self._window: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||||
|
if request.url.path in _EXEMPT_PATHS:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not token:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||||
|
tier: str = payload.get("tier", "free")
|
||||||
|
except JWTError:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||||
|
now = time.monotonic()
|
||||||
|
window_start = now - 60.0
|
||||||
|
|
||||||
|
# Slide the window: discard timestamps older than 60 seconds.
|
||||||
|
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||||
|
|
||||||
|
if len(timestamps) >= limit:
|
||||||
|
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"detail": (
|
||||||
|
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||||
|
f"Retry in {retry_after}s."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=429,
|
||||||
|
headers={
|
||||||
|
"Retry-After": str(retry_after),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
timestamps.append(now)
|
||||||
|
self._window[user_id] = timestamps
|
||||||
|
return await call_next(request)
|
||||||
139
app/api/middleware/sanitizer.py
Normal file
139
app/api/middleware/sanitizer.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Response sanitizer middleware.
|
||||||
|
|
||||||
|
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||||
|
that could reveal server-side prompt IP:
|
||||||
|
- System prompt openers ("You are a/an/the …")
|
||||||
|
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||||
|
- LangChain tool schema fragments (``"type": "function"``)
|
||||||
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
|
Binary responses (storage blobs, backup data) are never touched — the
|
||||||
|
middleware only activates for paths under /api/v1/chat.
|
||||||
|
|
||||||
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
|
names of the fields that were modified.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||||
|
# then compiled regexes.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FINGERPRINTS: tuple[str, ...] = (
|
||||||
|
"You are an intent classifier",
|
||||||
|
"Respond with just the agent name",
|
||||||
|
"Summarize these agent results",
|
||||||
|
"Available agents:",
|
||||||
|
"route to:",
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||||
|
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||||
|
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||||
|
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||||
|
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||||
|
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||||
|
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||||
|
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||||
|
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||||
|
|
||||||
|
Returns ``(cleaned_text, was_changed)``.
|
||||||
|
"""
|
||||||
|
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||||
|
for fp in _FINGERPRINTS:
|
||||||
|
if fp in text:
|
||||||
|
return "[REDACTED]", True
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
for pattern in _PATTERNS:
|
||||||
|
new_text, n = pattern.subn("[REDACTED]", text)
|
||||||
|
if n:
|
||||||
|
text = new_text
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
return text, changed
|
||||||
|
|
||||||
|
|
||||||
|
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||||
|
response: Response = await call_next(request)
|
||||||
|
|
||||||
|
# Only process chat endpoint responses.
|
||||||
|
if not request.url.path.startswith("/api/v1/chat"):
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Read body — collect streaming chunks.
|
||||||
|
body_bytes = b""
|
||||||
|
async for chunk in response.body_iterator:
|
||||||
|
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||||
|
|
||||||
|
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||||
|
try:
|
||||||
|
body = json.loads(body_bytes.decode("utf-8"))
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
|
return Response(
|
||||||
|
content=body_bytes,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(body, dict):
|
||||||
|
return Response(
|
||||||
|
content=body_bytes,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Walk top-level string fields and sanitise.
|
||||||
|
sanitised_fields: list[str] = []
|
||||||
|
for key, value in body.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
cleaned, changed = _sanitize_text(value)
|
||||||
|
if changed:
|
||||||
|
body[key] = cleaned
|
||||||
|
sanitised_fields.append(key)
|
||||||
|
|
||||||
|
if sanitised_fields:
|
||||||
|
logger.warning(
|
||||||
|
"Sanitizer redacted prompt fragments",
|
||||||
|
extra={
|
||||||
|
"path": request.url.path,
|
||||||
|
"fields": sanitised_fields,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_body = json.dumps(body).encode("utf-8")
|
||||||
|
headers = dict(response.headers)
|
||||||
|
headers["content-length"] = str(len(new_body))
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=new_body,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=headers,
|
||||||
|
media_type="application/json",
|
||||||
|
)
|
||||||
@@ -1,33 +1,36 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
migrates them to PostgreSQL.
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
|
||||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
|
||||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
def _hash_password(password: str) -> str:
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
|
|||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (signed JWT, expires_at_ms)."""
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
access_payload = {
|
payload = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"email": email,
|
"email": email,
|
||||||
"tier": tier,
|
"tier": tier,
|
||||||
"exp": access_exp,
|
"exp": exp,
|
||||||
"iat": now,
|
"iat": now,
|
||||||
}
|
}
|
||||||
access_token = jwt.encode(
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
return token, exp * 1000 # ms for client
|
||||||
)
|
|
||||||
refresh_token = str(uuid.uuid4())
|
|
||||||
_refresh_tokens[refresh_token] = user_id
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
expires_at=access_exp * 1000, # milliseconds for client
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ────────────────────────────────────────────────────
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class _RegisterRequest(BaseModel):
|
class _RegisterRequest(BaseModel):
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
|
|||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Create a new account and return JWT tokens."""
|
"""Create a new account and return JWT tokens."""
|
||||||
if body.email in _users:
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
user_id = str(uuid.uuid4())
|
|
||||||
_users[body.email] = {
|
user = User(
|
||||||
"id": user_id,
|
id=str(uuid.uuid4()),
|
||||||
"email": body.email,
|
email=body.email,
|
||||||
"password_hash": _hash_password(body.password),
|
password_hash=_hash_password(body.password),
|
||||||
"tier": "free",
|
tier="free",
|
||||||
}
|
)
|
||||||
return _make_tokens(user_id, body.email, "free")
|
db.add(user)
|
||||||
|
await db.flush() # get user.id without committing
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=AuthTokens)
|
@router.post("/login", response_model=AuthTokens)
|
||||||
async def login(body: _LoginRequest) -> AuthTokens:
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Validate credentials and return JWT tokens."""
|
"""Validate credentials and return JWT tokens."""
|
||||||
user = _users.get(body.email)
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=AuthTokens)
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Rotate a refresh token and return a new token pair."""
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
token_hash = _hash_token(body.refresh_token)
|
||||||
if user_id is None:
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
|
||||||
|
# Rotate: delete old token, issue new one.
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
@router.get("/me", response_model=UserProfile)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from typing import Any
|
|||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
from app.schemas import BackupMetadata, UserProfile
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
from app.storage.blob_store import BlobStore
|
from app.storage.blob_store import BlobStore
|
||||||
from app.storage.encryption import reject_if_tampered
|
from app.storage.encryption import reject_if_tampered
|
||||||
@@ -27,32 +28,11 @@ _blob_store = BlobStore()
|
|||||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
||||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
||||||
|
|
||||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
|
||||||
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
|
||||||
"free": 0,
|
|
||||||
"pro": 5,
|
|
||||||
"power": 25,
|
|
||||||
"team": -1, # unlimited
|
|
||||||
}
|
|
||||||
|
|
||||||
|
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
|
||||||
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
|
|
||||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||||
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
|
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||||
if limit_gb == 0:
|
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail="Backup is not available on the free tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024**3
|
|
||||||
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
|
||||||
if used + size_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("")
|
@router.put("")
|
||||||
@@ -69,7 +49,7 @@ async def upload_backup(
|
|||||||
"""
|
"""
|
||||||
blob = await request.body()
|
blob = await request.body()
|
||||||
reject_if_tampered(blob, x_backup_checksum)
|
reject_if_tampered(blob, x_backup_checksum)
|
||||||
_check_backup_quota(current_user.id, current_user.tier, len(blob))
|
_check_backup_quota(current_user.id, len(blob))
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
s3_key = await _blob_store.upload(
|
||||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||||
|
|||||||
@@ -1,44 +1,25 @@
|
|||||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
Subscription records are kept in-memory until Step 12 migrates them to
|
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||||
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||||
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
and delegates everything else to the service singleton.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import stripe as stripe_lib
|
from fastapi import APIRouter, Depends, Header, Request, status
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.db import get_session
|
||||||
from app.schemas import BillingTier, UserProfile
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
|
||||||
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
|
||||||
|
|
||||||
_TIER_PRICE_IDS: dict[str, str] = {
|
|
||||||
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
|
||||||
"power": "price_power_monthly",
|
|
||||||
"team": "price_team_monthly",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _stripe_configured() -> bool:
|
|
||||||
return bool(settings.STRIPE_SECRET_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
def _stripe() -> Any:
|
|
||||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
||||||
return stripe_lib
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ─────────────────────────────────────────────────────
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -57,40 +38,15 @@ async def create_checkout(
|
|||||||
|
|
||||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||||
"""
|
"""
|
||||||
if body.tier == "free":
|
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||||
raise HTTPException(
|
return {"checkout_url": url}
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Cannot create a checkout session for the free tier",
|
|
||||||
)
|
|
||||||
|
|
||||||
if _stripe_configured():
|
|
||||||
price_id = _TIER_PRICE_IDS.get(body.tier)
|
|
||||||
if not price_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unknown tier: {body.tier}",
|
|
||||||
)
|
|
||||||
s = _stripe()
|
|
||||||
session = s.checkout.Session.create(
|
|
||||||
payment_method_types=["card"],
|
|
||||||
mode="subscription",
|
|
||||||
line_items=[{"price": price_id, "quantity": 1}],
|
|
||||||
success_url=(
|
|
||||||
"https://app.adiuva.app/billing/success"
|
|
||||||
"?session_id={CHECKOUT_SESSION_ID}"
|
|
||||||
),
|
|
||||||
cancel_url="https://app.adiuva.app/billing/cancel",
|
|
||||||
metadata={"user_id": current_user.id, "tier": body.tier},
|
|
||||||
)
|
|
||||||
return {"checkout_url": session.url}
|
|
||||||
|
|
||||||
return {"checkout_url": "https://stripe.com/stub-checkout"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/webhook", response_model=dict)
|
@router.post("/webhook", response_model=dict)
|
||||||
async def stripe_webhook(
|
async def stripe_webhook(
|
||||||
request: Request,
|
request: Request,
|
||||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Handle Stripe webhook events.
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
@@ -98,57 +54,17 @@ async def stripe_webhook(
|
|||||||
Returns 200 immediately when Stripe is not configured (local dev).
|
Returns 200 immediately when Stripe is not configured (local dev).
|
||||||
"""
|
"""
|
||||||
payload = await request.body()
|
payload = await request.body()
|
||||||
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
if not _stripe_configured():
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
try:
|
|
||||||
s = _stripe()
|
|
||||||
event = s.Webhook.construct_event(
|
|
||||||
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
|
|
||||||
)
|
|
||||||
except stripe_lib.error.SignatureVerificationError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Invalid Stripe signature",
|
|
||||||
)
|
|
||||||
|
|
||||||
event_type: str = event["type"]
|
|
||||||
data: dict[str, Any] = event["data"]["object"]
|
|
||||||
|
|
||||||
if event_type == "checkout.session.completed":
|
|
||||||
user_id = data.get("metadata", {}).get("user_id")
|
|
||||||
tier = data.get("metadata", {}).get("tier", "free")
|
|
||||||
sub_id = data.get("subscription")
|
|
||||||
if user_id:
|
|
||||||
_subscriptions[user_id] = {
|
|
||||||
"tier": tier,
|
|
||||||
"stripe_subscription_id": sub_id,
|
|
||||||
"status": "active",
|
|
||||||
"current_period_end": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.updated":
|
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.deleted":
|
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif event_type == "invoice.payment_failed":
|
|
||||||
# TODO(Step12): flag subscription as past_due, notify user
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/subscription", response_model=dict)
|
@router.get("/subscription", response_model=dict)
|
||||||
async def get_subscription(
|
async def get_subscription(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return the current subscription info for the authenticated user."""
|
"""Return the current subscription info for the authenticated user."""
|
||||||
sub = _subscriptions.get(current_user.id)
|
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||||
if sub is None:
|
if sub is None:
|
||||||
return {
|
return {
|
||||||
"tier": current_user.tier,
|
"tier": current_user.tier,
|
||||||
@@ -159,26 +75,11 @@ async def get_subscription(
|
|||||||
return sub
|
return sub
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/subscription", response_model=dict)
|
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||||
async def cancel_subscription(
|
async def cancel_subscription(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
sub = _subscriptions.get(current_user.id)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
if sub is None or not sub.get("stripe_subscription_id"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="No active subscription found",
|
|
||||||
)
|
|
||||||
|
|
||||||
if _stripe_configured():
|
|
||||||
s = _stripe()
|
|
||||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
|
||||||
|
|
||||||
_subscriptions[current_user.id] = {
|
|
||||||
**sub,
|
|
||||||
"tier": "free",
|
|
||||||
"status": "canceled",
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Plugins routes: browse and install plugins from the marketplace.
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
The catalog and installation records are kept in-memory as stubs.
|
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
|
||||||
Step 10 replaces these with PluginRegistry, RevenueShare, and the plugins DB table.
|
in Step 10. Step 12 will swap those services' in-memory stores for
|
||||||
|
PostgreSQL persistence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -12,49 +13,12 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
|
|
||||||
# ── In-memory catalog (Step 10 replaces with PluginRegistry + DB) ─────
|
|
||||||
|
|
||||||
_plugin_catalog: list[PluginManifest] = [
|
|
||||||
PluginManifest(
|
|
||||||
id="plugin-github-sync",
|
|
||||||
name="GitHub Sync",
|
|
||||||
description="Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
version="1.0.0",
|
|
||||||
author="Adiuva",
|
|
||||||
permissions=["read:tasks", "write:tasks"],
|
|
||||||
category="productivity",
|
|
||||||
price_cents=0,
|
|
||||||
),
|
|
||||||
PluginManifest(
|
|
||||||
id="plugin-slack-notify",
|
|
||||||
name="Slack Notifier",
|
|
||||||
description="Post task and checkpoint updates to Slack channels.",
|
|
||||||
version="1.2.0",
|
|
||||||
author="Adiuva",
|
|
||||||
permissions=["read:tasks", "read:checkpoints"],
|
|
||||||
category="communication",
|
|
||||||
price_cents=499,
|
|
||||||
),
|
|
||||||
PluginManifest(
|
|
||||||
id="plugin-time-tracker",
|
|
||||||
name="Time Tracker",
|
|
||||||
description="Track time spent on tasks with automatic reporting.",
|
|
||||||
version="0.9.1",
|
|
||||||
author="Third Party",
|
|
||||||
permissions=["read:tasks", "write:tasks"],
|
|
||||||
category="productivity",
|
|
||||||
price_cents=999,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# plugin_id → set of user_ids who have installed it
|
|
||||||
_installations: dict[str, set[str]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -67,43 +31,12 @@ def _require_plugin_tier(user: UserProfile) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Filter + sort helpers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _apply_filters(
|
|
||||||
plugins: list[PluginManifest],
|
|
||||||
category: str | None,
|
|
||||||
q: str | None,
|
|
||||||
) -> list[PluginManifest]:
|
|
||||||
result = plugins
|
|
||||||
if category:
|
|
||||||
result = [p for p in result if p.category == category]
|
|
||||||
if q:
|
|
||||||
q_lower = q.lower()
|
|
||||||
result = [
|
|
||||||
p for p in result
|
|
||||||
if q_lower in p.name.lower() or q_lower in p.description.lower()
|
|
||||||
]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_sort(
|
|
||||||
plugins: list[PluginManifest],
|
|
||||||
sort: str,
|
|
||||||
) -> list[PluginManifest]:
|
|
||||||
if sort == "installs":
|
|
||||||
return sorted(plugins, key=lambda p: len(_installations.get(p.id, set())), reverse=True)
|
|
||||||
if sort == "rating":
|
|
||||||
# Placeholder until Step 10 introduces avg_rating from DB
|
|
||||||
return sorted(plugins, key=lambda p: -p.price_cents)
|
|
||||||
return plugins # "newest" = catalog insertion order
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local detail schema ────────────────────────────────────────────────
|
# ── Local detail schema ────────────────────────────────────────────────
|
||||||
|
|
||||||
class _PluginDetail(BaseModel):
|
class _PluginDetail(BaseModel):
|
||||||
plugin: PluginManifest
|
plugin: PluginManifest
|
||||||
install_count: int
|
install_count: int
|
||||||
ratings: list[Any] # Step 10 populates from plugin_reviews table
|
ratings: list[Any] # Step 12 populates from plugin_reviews table
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
@@ -118,9 +51,7 @@ async def list_plugins(
|
|||||||
) -> PluginListResponse:
|
) -> PluginListResponse:
|
||||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||||
_require_plugin_tier(current_user)
|
_require_plugin_tier(current_user)
|
||||||
filtered = _apply_filters(_plugin_catalog, category, q)
|
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
|
||||||
sorted_plugins = _apply_sort(filtered, sort)
|
|
||||||
return PluginListResponse(plugins=sorted_plugins, total=len(sorted_plugins), page=page)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||||
@@ -130,13 +61,13 @@ async def get_plugin(
|
|||||||
) -> _PluginDetail:
|
) -> _PluginDetail:
|
||||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||||
_require_plugin_tier(current_user)
|
_require_plugin_tier(current_user)
|
||||||
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
|
entry = await registry.get_plugin(plugin_id)
|
||||||
if plugin is None:
|
if entry is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
return _PluginDetail(
|
return _PluginDetail(
|
||||||
plugin=plugin,
|
plugin=entry["manifest"],
|
||||||
install_count=len(_installations.get(plugin_id, set())),
|
install_count=entry["install_count"],
|
||||||
ratings=[], # Step 10 populates from plugin_reviews table
|
ratings=[], # Step 12 populates from plugin_reviews table
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -146,20 +77,21 @@ async def install_plugin(
|
|||||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Install a plugin. Triggers Stripe Connect for paid plugins when configured.
|
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||||
|
|
||||||
Requires Power tier or above.
|
Requires Power tier or above.
|
||||||
"""
|
"""
|
||||||
_require_plugin_tier(current_user)
|
_require_plugin_tier(current_user)
|
||||||
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
|
entry = await registry.get_plugin(plugin_id)
|
||||||
if plugin is None:
|
if entry is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
if plugin.price_cents > 0 and settings.STRIPE_SECRET_KEY:
|
await revenue_share.record_install(
|
||||||
# TODO(Step10): stripe.PaymentIntent.create with destination charge (70/30 split)
|
plugin_id=plugin_id,
|
||||||
pass
|
user_id=current_user.id,
|
||||||
|
amount_cents=entry["manifest"].price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
_installations.setdefault(plugin_id, set()).add(current_user.id)
|
|
||||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||||
return {"ok": True, "download_url": download_url}
|
return {"ok": True, "download_url": download_url}
|
||||||
|
|
||||||
@@ -170,5 +102,5 @@ async def uninstall_plugin(
|
|||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Unregister a plugin installation."""
|
"""Unregister a plugin installation."""
|
||||||
_installations.get(plugin_id, set()).discard(current_user.id)
|
await registry.record_uninstall(plugin_id)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
from app.storage.blob_store import BlobStore
|
from app.storage.blob_store import BlobStore
|
||||||
from app.storage.encryption import reject_if_tampered
|
from app.storage.encryption import reject_if_tampered
|
||||||
@@ -25,14 +26,6 @@ _blob_store = BlobStore()
|
|||||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
||||||
_records: dict[str, dict[str, Any]] = {}
|
_records: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
|
||||||
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
|
||||||
"free": 0,
|
|
||||||
"pro": 5,
|
|
||||||
"power": 25,
|
|
||||||
"team": -1, # unlimited
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local response schemas ─────────────────────────────────────────────
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -51,18 +44,10 @@ class _RecordMeta(BaseModel):
|
|||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
|
def _check_quota(user_id: str, additional_bytes: int) -> None:
|
||||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
||||||
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
|
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||||
if limit_gb == -1:
|
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024**3
|
|
||||||
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
|
||||||
if used + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
||||||
@@ -83,7 +68,7 @@ async def create_record(
|
|||||||
) -> _CreateResponse:
|
) -> _CreateResponse:
|
||||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
_check_quota(current_user.id, current_user.tier, len(body.blob))
|
_check_quota(current_user.id, len(body.blob))
|
||||||
|
|
||||||
record_id = str(uuid.uuid4())
|
record_id = str(uuid.uuid4())
|
||||||
now = int(time.time() * 1000)
|
now = int(time.time() * 1000)
|
||||||
@@ -159,7 +144,7 @@ async def update_record(
|
|||||||
|
|
||||||
delta = len(body.blob) - record["size_bytes"]
|
delta = len(body.blob) - record["size_bytes"]
|
||||||
if delta > 0:
|
if delta > 0:
|
||||||
_check_quota(current_user.id, current_user.tier, delta)
|
_check_quota(current_user.id, delta)
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
s3_key = await _blob_store.upload(
|
||||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
current_user.id, record["table"], record_id, body.blob, body.checksum
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
|
||||||
|
__all__ = ["stripe_service", "tier_manager"]
|
||||||
|
|||||||
256
app/billing/stripe_service.py
Normal file
256
app/billing/stripe_service.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
|
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||||
|
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||||
|
configured, enabling local development without live credentials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||||
|
TIER_PRICE_IDS: dict[str, str] = {
|
||||||
|
"pro": "price_pro_monthly",
|
||||||
|
"power": "price_power_monthly",
|
||||||
|
"team": "price_team_monthly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StripeService:
|
||||||
|
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||||
|
|
||||||
|
# ── Internal helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _configured(self) -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_checkout_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
tier: str,
|
||||||
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
|
) -> str:
|
||||||
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
|
Returns a stub URL when Stripe is not configured.
|
||||||
|
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||||
|
"""
|
||||||
|
if tier == "free":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot create a checkout session for the free tier",
|
||||||
|
)
|
||||||
|
|
||||||
|
price_id = TIER_PRICE_IDS.get(tier)
|
||||||
|
if not price_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._configured():
|
||||||
|
return "https://stripe.com/stub-checkout"
|
||||||
|
|
||||||
|
s = self._client()
|
||||||
|
session = s.checkout.Session.create(
|
||||||
|
payment_method_types=["card"],
|
||||||
|
mode="subscription",
|
||||||
|
line_items=[{"price": price_id, "quantity": 1}],
|
||||||
|
success_url=success_url,
|
||||||
|
cancel_url=cancel_url,
|
||||||
|
metadata={"user_id": user_id, "tier": tier},
|
||||||
|
)
|
||||||
|
return session.url
|
||||||
|
|
||||||
|
async def handle_webhook(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
|
Verifies the signature, then dispatches on event type.
|
||||||
|
Raises ``HTTP 400`` on signature mismatch.
|
||||||
|
No-ops when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
event = s.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||||
|
)
|
||||||
|
except stripe_lib.error.SignatureVerificationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid Stripe signature",
|
||||||
|
)
|
||||||
|
|
||||||
|
event_type: str = event["type"]
|
||||||
|
data: dict[str, Any] = event["data"]["object"]
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
user_id = data.get("metadata", {}).get("user_id")
|
||||||
|
tier = data.get("metadata", {}).get("tier", "free")
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
await self._upsert_subscription(
|
||||||
|
db, user_id, sub_id, tier, "active", period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
new_status = data.get("status", "active")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status=new_status, current_period_end=period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, tier="free", status="canceled"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status="past_due"
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_subscription(
|
||||||
|
self, user_id: str, db: AsyncSession
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"tier": sub.tier,
|
||||||
|
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||||
|
"status": sub.status,
|
||||||
|
"current_period_end": (
|
||||||
|
int(sub.current_period_end.timestamp() * 1000)
|
||||||
|
if sub.current_period_end
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
|
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||||
|
|
||||||
|
Raises ``HTTP 404`` when no active subscription exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None or not sub.stripe_subscription_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No active subscription found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._configured():
|
||||||
|
s = self._client()
|
||||||
|
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||||
|
|
||||||
|
sub.tier = "free"
|
||||||
|
sub.status = "canceled"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _upsert_subscription(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
stripe_subscription_id: str | None,
|
||||||
|
tier: str,
|
||||||
|
sub_status: str,
|
||||||
|
current_period_end: datetime | None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
sub = Subscription(user_id=user_id)
|
||||||
|
db.add(sub)
|
||||||
|
sub.stripe_subscription_id = stripe_subscription_id
|
||||||
|
sub.tier = tier
|
||||||
|
sub.status = sub_status
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
async def _update_subscription_by_stripe_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
stripe_subscription_id: str,
|
||||||
|
*,
|
||||||
|
tier: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
current_period_end: datetime | None = None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if tier is not None:
|
||||||
|
sub.tier = tier
|
||||||
|
if status is not None:
|
||||||
|
sub.status = status
|
||||||
|
if current_period_end is not None:
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared across the app.
|
||||||
|
stripe_service = StripeService()
|
||||||
189
app/billing/tier_manager.py
Normal file
189
app/billing/tier_manager.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
|
``TierManager`` is the single source of truth for what each billing tier
|
||||||
|
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||||
|
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||||
|
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.schemas import BillingTier
|
||||||
|
|
||||||
|
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||||
|
FEATURES: dict[str, dict[str, Any]] = {
|
||||||
|
"free": {
|
||||||
|
"agents": 3,
|
||||||
|
"batch_active": 2,
|
||||||
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
|
"providers": 1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"pro": {
|
||||||
|
"agents": -1, # unlimited
|
||||||
|
"batch_active": 10,
|
||||||
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"power": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1, # unlimited
|
||||||
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"team": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"cloud_storage_gb": -1, # unlimited
|
||||||
|
"backup_gb": -1, # unlimited
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Requests-per-minute limit per tier.
|
||||||
|
RATE_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TierManager:
|
||||||
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
|
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str | None = result.scalar_one_or_none()
|
||||||
|
if tier is None or tier not in FEATURES:
|
||||||
|
return "free"
|
||||||
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
|
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||||
|
|
||||||
|
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||||
|
"""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return value != 0
|
||||||
|
|
||||||
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
|
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||||
|
if not self.check_feature(tier, feature):
|
||||||
|
detail = (
|
||||||
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
|
if tier_name
|
||||||
|
else f"Feature '{feature}' is not available on your current tier."
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared across the app.
|
||||||
|
tier_manager = TierManager()
|
||||||
40
app/db.py
Normal file
40
app/db.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Database engine, session factory, and base model.
|
||||||
|
|
||||||
|
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||||
|
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||||
|
|
||||||
|
Usage in routes:
|
||||||
|
from app.db import get_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
echo=settings.ENV == "dev",
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Shared declarative base for all ORM models."""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""FastAPI dependency that yields an async DB session per request."""
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
11
app/main.py
11
app/main.py
@@ -3,6 +3,8 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +16,9 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: nothing to clean up for now
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
|
from app.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
@@ -33,6 +37,11 @@ def create_app() -> FastAPI:
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||||
|
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||||
|
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||||
|
app.add_middleware(SanitizerMiddleware)
|
||||||
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||||
|
|
||||||
|
|||||||
7
app/marketplace/__init__.py
Normal file
7
app/marketplace/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Plugin marketplace package.
|
||||||
|
|
||||||
|
Three service classes introduced in Step 10:
|
||||||
|
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||||
|
- ``ReviewQueue`` — approval workflow + security checklist
|
||||||
|
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||||
|
"""
|
||||||
211
app/marketplace/plugin_registry.py
Normal file
211
app/marketplace/plugin_registry.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""Plugin catalog registry.
|
||||||
|
|
||||||
|
Maintains the authoritative list of plugins, their review status, and
|
||||||
|
aggregate install counts. Storage is in-memory until Step 12 migrates to
|
||||||
|
the ``plugins`` PostgreSQL table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.schemas import PluginListResponse, PluginManifest
|
||||||
|
|
||||||
|
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
|
||||||
|
|
||||||
|
_SEED_PLUGINS: list[dict[str, Any]] = [
|
||||||
|
{
|
||||||
|
"manifest": PluginManifest(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author="Adiuva",
|
||||||
|
permissions=["read:tasks", "write:tasks"],
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
"rejection_reason": None,
|
||||||
|
"submitted_at": int(time.time()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"manifest": PluginManifest(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and checkpoint updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author="Adiuva",
|
||||||
|
permissions=["read:tasks", "read:checkpoints"],
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
"rejection_reason": None,
|
||||||
|
"submitted_at": int(time.time()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"manifest": PluginManifest(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author="Third Party",
|
||||||
|
permissions=["read:tasks", "write:tasks"],
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
"rejection_reason": None,
|
||||||
|
"submitted_at": int(time.time()),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
class PluginRegistry:
|
||||||
|
"""In-process plugin catalog.
|
||||||
|
|
||||||
|
All mutating methods are ``async`` to make the future DB swap transparent
|
||||||
|
to callers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# plugin_id → entry dict (deep-copied so each instance is independent)
|
||||||
|
self._catalog: dict[str, dict[str, Any]] = {
|
||||||
|
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def list_plugins(
|
||||||
|
self,
|
||||||
|
category: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||||
|
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
|
||||||
|
|
||||||
|
if category:
|
||||||
|
entries = [e for e in entries if e["manifest"].category == category]
|
||||||
|
|
||||||
|
if query:
|
||||||
|
q_lower = query.lower()
|
||||||
|
entries = [
|
||||||
|
e
|
||||||
|
for e in entries
|
||||||
|
if q_lower in e["manifest"].name.lower()
|
||||||
|
or q_lower in e["manifest"].description.lower()
|
||||||
|
]
|
||||||
|
|
||||||
|
if sort == "installs":
|
||||||
|
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
|
||||||
|
elif sort == "rating":
|
||||||
|
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
|
||||||
|
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
|
||||||
|
|
||||||
|
total = len(entries)
|
||||||
|
start = (page - 1) * _PAGE_SIZE
|
||||||
|
page_entries = entries[start : start + _PAGE_SIZE]
|
||||||
|
|
||||||
|
return PluginListResponse(
|
||||||
|
plugins=[e["manifest"] for e in page_entries],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||||
|
entry = self._catalog.get(plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"manifest": entry["manifest"],
|
||||||
|
"status": entry["status"],
|
||||||
|
"install_count": entry["install_count"],
|
||||||
|
"avg_rating": entry["avg_rating"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Mutations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def submit_plugin(
|
||||||
|
self,
|
||||||
|
manifest: PluginManifest,
|
||||||
|
package_s3_key: str,
|
||||||
|
) -> str:
|
||||||
|
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||||
|
|
||||||
|
Returns the plugin_id. If a plugin with the same id already exists
|
||||||
|
it is overwritten (re-submission after rejection).
|
||||||
|
"""
|
||||||
|
plugin_id = manifest.id or str(uuid.uuid4())
|
||||||
|
self._catalog[plugin_id] = {
|
||||||
|
"manifest": manifest,
|
||||||
|
"status": "pending_review",
|
||||||
|
"s3_package_key": package_s3_key,
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
"rejection_reason": None,
|
||||||
|
"submitted_at": int(time.time()),
|
||||||
|
}
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
async def approve_plugin(self, plugin_id: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'approved'``.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
if plugin_id not in self._catalog:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
self._catalog[plugin_id]["status"] = "approved"
|
||||||
|
self._catalog[plugin_id]["rejection_reason"] = None
|
||||||
|
|
||||||
|
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
if plugin_id not in self._catalog:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
self._catalog[plugin_id]["status"] = "rejected"
|
||||||
|
self._catalog[plugin_id]["rejection_reason"] = reason
|
||||||
|
|
||||||
|
async def record_install(self, plugin_id: str) -> None:
|
||||||
|
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||||
|
if plugin_id in self._catalog:
|
||||||
|
self._catalog[plugin_id]["install_count"] += 1
|
||||||
|
|
||||||
|
async def record_uninstall(self, plugin_id: str) -> None:
|
||||||
|
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||||
|
if plugin_id in self._catalog:
|
||||||
|
current = self._catalog[plugin_id]["install_count"]
|
||||||
|
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
|
||||||
|
|
||||||
|
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||||
|
|
||||||
|
def _get_pending_entries(self) -> list[dict[str, Any]]:
|
||||||
|
"""Return all entries with status='pending_review' (synchronous helper)."""
|
||||||
|
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = PluginRegistry()
|
||||||
127
app/marketplace/plugin_review.py
Normal file
127
app/marketplace/plugin_review.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Plugin review workflow.
|
||||||
|
|
||||||
|
Manages the approval queue for newly submitted plugins and enforces a
|
||||||
|
security checklist before any plugin is made visible in the marketplace.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_review import review_queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
|
# ── Security policy ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"read:tasks",
|
||||||
|
"write:tasks",
|
||||||
|
"read:projects",
|
||||||
|
"write:projects",
|
||||||
|
"read:notes",
|
||||||
|
"write:notes",
|
||||||
|
"read:checkpoints",
|
||||||
|
"write:checkpoints",
|
||||||
|
"read:calendar",
|
||||||
|
"write:calendar",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_manifest(manifest: PluginManifest) -> None:
|
||||||
|
"""Enforce the plugin security checklist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``ValueError`` on the first violation found. Callers should catch
|
||||||
|
this and return HTTP 422 / reject the submission.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||||
|
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||||
|
3. No manifest field contains raw binary data
|
||||||
|
"""
|
||||||
|
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid plugin id format: '{manifest.id}'. "
|
||||||
|
"Only lowercase letters, digits, and hyphens are allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
for perm in manifest.permissions:
|
||||||
|
if perm not in ALLOWED_PERMISSIONS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown permission: '{perm}'. "
|
||||||
|
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name, value in manifest.model_dump().items():
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewQueue:
|
||||||
|
"""Approval queue for pending plugin submissions.
|
||||||
|
|
||||||
|
Delegates status changes to the shared ``PluginRegistry`` singleton so
|
||||||
|
there is a single source of truth for plugin state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Completed reviews — Step 12 stores in plugin_reviews table
|
||||||
|
self._reviews: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def get_pending(self) -> list[dict[str, Any]]:
|
||||||
|
"""Return all plugins currently awaiting review.
|
||||||
|
|
||||||
|
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||||
|
"""
|
||||||
|
entries = registry._get_pending_entries()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"plugin_id": e["manifest"].id,
|
||||||
|
"manifest": e["manifest"],
|
||||||
|
"submitted_at": e["submitted_at"],
|
||||||
|
}
|
||||||
|
for e in entries
|
||||||
|
]
|
||||||
|
|
||||||
|
async def submit_review(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
reviewer_id: str,
|
||||||
|
decision: Literal["approved", "rejected"],
|
||||||
|
notes: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Record a review decision and update the plugin's status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``KeyError`` if *plugin_id* is not found in the registry.
|
||||||
|
"""
|
||||||
|
if decision == "approved":
|
||||||
|
await registry.approve_plugin(plugin_id)
|
||||||
|
else:
|
||||||
|
await registry.reject_plugin(plugin_id, reason=notes)
|
||||||
|
|
||||||
|
self._reviews.append(
|
||||||
|
{
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"reviewer_id": reviewer_id,
|
||||||
|
"decision": decision,
|
||||||
|
"notes": notes,
|
||||||
|
"reviewed_at": int(time.time()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
review_queue = ReviewQueue()
|
||||||
205
app/marketplace/revenue_share.py
Normal file
205
app/marketplace/revenue_share.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Revenue share tracking and Stripe Connect payouts.
|
||||||
|
|
||||||
|
Records every plugin installation as a revenue event and facilitates
|
||||||
|
70 % / 30 % payouts to developers via Stripe Connect. Storage is
|
||||||
|
in-memory until Step 12 migrates to the ``revenue_events`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Revenue split constants ───────────────────────────────────────────
|
||||||
|
|
||||||
|
DEVELOPER_SHARE: float = 0.70
|
||||||
|
PLATFORM_SHARE: float = 0.30
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueShare:
|
||||||
|
"""Records installation revenue events and coordinates developer payouts.
|
||||||
|
|
||||||
|
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||||
|
is not configured, consistent with the rest of the billing layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Step 12 replaces with revenue_events DB table
|
||||||
|
self._events: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe_configured() -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe() -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Core operations ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def record_install(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
user_id: str,
|
||||||
|
amount_cents: int,
|
||||||
|
) -> None:
|
||||||
|
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||||
|
|
||||||
|
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||||
|
the event is still recorded for analytics.
|
||||||
|
|
||||||
|
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||||
|
destination charge. If Stripe is not configured or the charge fails
|
||||||
|
the installation still succeeds (the event is recorded and the install
|
||||||
|
count is incremented) — a warning is logged for monitoring.
|
||||||
|
"""
|
||||||
|
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||||
|
stripe_transfer_id: str | None = None
|
||||||
|
|
||||||
|
if amount_cents > 0 and self._stripe_configured():
|
||||||
|
plugin_entry = registry._catalog.get(plugin_id)
|
||||||
|
developer_stripe_account: str | None = None
|
||||||
|
if plugin_entry:
|
||||||
|
# Step 12: look up developer's Stripe account from DB
|
||||||
|
# For now, the author field is used as a placeholder key.
|
||||||
|
developer_stripe_account = None # no real account yet
|
||||||
|
|
||||||
|
if developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
transfer = s.Transfer.create(
|
||||||
|
amount=developer_share_cents,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Revenue share for plugin {plugin_id}",
|
||||||
|
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
stripe_transfer_id = transfer["id"]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Stripe Connect transfer failed for plugin %s: %s",
|
||||||
|
plugin_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"No Stripe account on file for plugin %s developer; "
|
||||||
|
"skipping transfer.",
|
||||||
|
plugin_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._events.append(
|
||||||
|
{
|
||||||
|
"plugin_id": plugin_id,
|
||||||
|
"user_id": user_id,
|
||||||
|
"amount_cents": amount_cents,
|
||||||
|
"developer_share_cents": developer_share_cents,
|
||||||
|
"stripe_transfer_id": stripe_transfer_id,
|
||||||
|
"paid_at": None,
|
||||||
|
"created_at": int(time.time()),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await registry.record_install(plugin_id)
|
||||||
|
|
||||||
|
async def get_earnings(
|
||||||
|
self,
|
||||||
|
developer_id: str,
|
||||||
|
period: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return aggregated earnings for *developer_id*.
|
||||||
|
|
||||||
|
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||||
|
|
||||||
|
Returns::
|
||||||
|
|
||||||
|
{
|
||||||
|
"developer_id": str,
|
||||||
|
"period": str | None,
|
||||||
|
"total_installs": int,
|
||||||
|
"total_revenue_cents": int,
|
||||||
|
"developer_share_cents": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Find plugin ids belonging to this developer
|
||||||
|
developer_plugin_ids: set[str] = {
|
||||||
|
pid
|
||||||
|
for pid, entry in registry._catalog.items()
|
||||||
|
if entry["manifest"].author == developer_id
|
||||||
|
}
|
||||||
|
|
||||||
|
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
|
||||||
|
|
||||||
|
if period:
|
||||||
|
# Filter by YYYY-MM prefix of the created_at timestamp
|
||||||
|
events = [
|
||||||
|
e
|
||||||
|
for e in events
|
||||||
|
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": len(events),
|
||||||
|
"total_revenue_cents": sum(e["amount_cents"] for e in events),
|
||||||
|
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def payout_developer(self, plugin_id: str, period: str) -> None:
|
||||||
|
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||||
|
|
||||||
|
Marks processed events with ``paid_at`` timestamp.
|
||||||
|
Stubs gracefully when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
unpaid = [
|
||||||
|
e
|
||||||
|
for e in self._events
|
||||||
|
if e["plugin_id"] == plugin_id
|
||||||
|
and e["paid_at"] is None
|
||||||
|
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
||||||
|
]
|
||||||
|
|
||||||
|
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
|
||||||
|
if total_dev_share <= 0 or not unpaid:
|
||||||
|
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._stripe_configured():
|
||||||
|
plugin_entry = registry._catalog.get(plugin_id)
|
||||||
|
developer_stripe_account: str | None = None # Step 12: fetch from DB
|
||||||
|
if plugin_entry and developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
s.Transfer.create(
|
||||||
|
amount=total_dev_share,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Payout for plugin {plugin_id} period {period}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
paid_ts = int(time.time())
|
||||||
|
for event in unpaid:
|
||||||
|
event["paid_at"] = paid_ts
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
revenue_share = RevenueShare()
|
||||||
269
app/models.py
Normal file
269
app/models.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
|
Table inventory:
|
||||||
|
users — account credentials + tier
|
||||||
|
refresh_tokens — hashed refresh token store
|
||||||
|
subscriptions — Stripe subscription records
|
||||||
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
|
backup_metadata — encrypted backup manifests
|
||||||
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(Base):
|
||||||
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, unique=True, index=True
|
||||||
|
)
|
||||||
|
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||||
|
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
304
tests/test_middleware.py
Normal file
304
tests/test_middleware.py
Normal file
@@ -0,0 +1,304 @@
|
|||||||
|
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
|
||||||
|
|
||||||
|
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
|
||||||
|
Rate limit: use unique user UUIDs per test so windows are independent;
|
||||||
|
the free-tier threshold (20 req/min) is exercised directly.
|
||||||
|
Sanitizer: the orchestrator is mocked to inject controlled prompt
|
||||||
|
fragments, and the chat endpoint response body is inspected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.main import app
|
||||||
|
from app.schemas import ChatResponse
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_CHAT_BODY = {
|
||||||
|
"message": "hello",
|
||||||
|
"context": {
|
||||||
|
"user_profile": {},
|
||||||
|
"relevant_documents": [],
|
||||||
|
"recent_tasks": [],
|
||||||
|
"conversation_history": [],
|
||||||
|
},
|
||||||
|
"execution_mode": "direct",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_jwt(
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
email: str = "test@example.com",
|
||||||
|
tier: str = "free",
|
||||||
|
exp_offset: int = 3600,
|
||||||
|
secret: str | None = None,
|
||||||
|
include_sub: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""Mint a test JWT signed with the configured (or custom) secret."""
|
||||||
|
uid = user_id or str(uuid.uuid4())
|
||||||
|
now = int(time.time())
|
||||||
|
payload: dict = {
|
||||||
|
"email": email,
|
||||||
|
"tier": tier,
|
||||||
|
"exp": now + exp_offset,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
if include_sub:
|
||||||
|
payload["sub"] = uid
|
||||||
|
key = secret or settings.JWT_SECRET
|
||||||
|
return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_header(token: str) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auth middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthMiddleware:
|
||||||
|
"""Tests exercised via GET /api/v1/auth/me."""
|
||||||
|
|
||||||
|
def test_valid_token_returns_profile(self) -> None:
|
||||||
|
uid = str(uuid.uuid4())
|
||||||
|
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == uid
|
||||||
|
assert data["email"] == "alice@example.com"
|
||||||
|
assert data["tier"] == "pro"
|
||||||
|
|
||||||
|
def test_missing_token_returns_401(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_expired_token_returns_401(self) -> None:
|
||||||
|
token = _make_jwt(exp_offset=-1) # already expired
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_wrong_signature_returns_401(self) -> None:
|
||||||
|
token = _make_jwt(secret="totally-wrong-secret")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_missing_sub_claim_returns_401(self) -> None:
|
||||||
|
token = _make_jwt(include_sub=False)
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_malformed_token_returns_401(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rate limiter middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitMiddleware:
|
||||||
|
"""Each test uses a fresh unique user_id so windows never collide."""
|
||||||
|
|
||||||
|
def _unique_token(self, tier: str = "free") -> str:
|
||||||
|
return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)
|
||||||
|
|
||||||
|
def test_free_tier_allows_up_to_20_requests(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_free_tier_blocks_21st_request(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_429_includes_retry_after_header(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert "retry-after" in resp.headers
|
||||||
|
retry_after = int(resp.headers["retry-after"])
|
||||||
|
assert retry_after >= 1
|
||||||
|
|
||||||
|
def test_429_response_has_detail_field(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert "detail" in resp.json()
|
||||||
|
|
||||||
|
def test_pro_tier_allows_60_requests(self) -> None:
|
||||||
|
token = self._unique_token("pro")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# Sample: first 60 succeed, 61st is blocked.
|
||||||
|
for _ in range(60):
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_independent_users_have_separate_windows(self) -> None:
|
||||||
|
token_a = self._unique_token("free")
|
||||||
|
token_b = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# Exhaust user A's quota.
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token_a))
|
||||||
|
assert (
|
||||||
|
client.get(
|
||||||
|
"/api/v1/auth/me", headers=_auth_header(token_a)
|
||||||
|
).status_code
|
||||||
|
== 429
|
||||||
|
)
|
||||||
|
# User B's quota is untouched.
|
||||||
|
resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b))
|
||||||
|
assert resp_b.status_code == 200
|
||||||
|
|
||||||
|
def test_exempt_path_register_never_rate_limited(self) -> None:
|
||||||
|
"""POST /auth/register is exempt — 25 calls should never return 429."""
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for i in range(25):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
|
||||||
|
)
|
||||||
|
# 201 on first, 409 on email collision — but never 429.
|
||||||
|
assert resp.status_code != 429
|
||||||
|
|
||||||
|
def test_exempt_path_login_never_rate_limited(self) -> None:
|
||||||
|
"""POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(25):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "nosuchuser@example.com", "password": "wrong"},
|
||||||
|
)
|
||||||
|
assert resp.status_code != 429
|
||||||
|
|
||||||
|
def test_exempt_path_health_never_rate_limited(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(25):
|
||||||
|
resp = client.get("/api/v1/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sanitizer middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizerMiddleware:
|
||||||
|
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
||||||
|
|
||||||
|
_CHAT_PATH = "/api/v1/chat"
|
||||||
|
|
||||||
|
def _token(self) -> str:
|
||||||
|
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||||
|
|
||||||
|
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||||
|
mock_response = ChatResponse(response=response_text, actions=[])
|
||||||
|
with patch(
|
||||||
|
"app.api.routes.chat.orchestrate",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
):
|
||||||
|
resp = client.post(
|
||||||
|
self._CHAT_PATH,
|
||||||
|
json=_CHAT_BODY,
|
||||||
|
headers=_auth_header(self._token()),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def test_clean_response_passes_through_unchanged(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(client, "Sure, I created the task for you.")
|
||||||
|
assert data["response"] == "Sure, I created the task for you."
|
||||||
|
|
||||||
|
def test_strips_system_prompt_opener(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "You are an intent classifier. Route to task_agent."
|
||||||
|
)
|
||||||
|
assert "You are" not in data["response"]
|
||||||
|
assert "[REDACTED]" in data["response"]
|
||||||
|
|
||||||
|
def test_strips_known_fingerprint(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "Respond with just the agent name and nothing else."
|
||||||
|
)
|
||||||
|
assert data["response"] == "[REDACTED]"
|
||||||
|
|
||||||
|
def test_strips_tool_schema_fragment(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, 'Here is the schema: {"type": "function", "name": "foo"}'
|
||||||
|
)
|
||||||
|
assert '"type": "function"' not in data["response"]
|
||||||
|
|
||||||
|
def test_strips_reasoning_tag(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "<thinking>I should route this to calendar_agent</thinking>Done."
|
||||||
|
)
|
||||||
|
assert "<thinking>" not in data["response"]
|
||||||
|
assert "[REDACTED]" in data["response"]
|
||||||
|
|
||||||
|
def test_strips_available_agents_fragment(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "Available agents: task_agent, calendar_agent"
|
||||||
|
)
|
||||||
|
assert "[REDACTED]" in data["response"]
|
||||||
|
|
||||||
|
def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
|
||||||
|
"""GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
|
||||||
|
token = self._token()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/plans/playbook",
|
||||||
|
headers=_auth_header(token),
|
||||||
|
)
|
||||||
|
# The sanitizer should not interfere — just check it returns something
|
||||||
|
# (200 or whatever the route returns; we only care it's not broken).
|
||||||
|
assert resp.status_code in (200, 401, 403, 404)
|
||||||
|
|
||||||
|
def test_sanitizer_preserves_empty_response(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(client, "")
|
||||||
|
assert data["response"] == ""
|
||||||
387
tests/test_plugins.py
Normal file
387
tests/test_plugins.py
Normal file
@@ -0,0 +1,387 @@
|
|||||||
|
"""Tests for Step 10: Plugin Marketplace.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- PluginRegistry: catalog management, filtering, sorting, install counts
|
||||||
|
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||||
|
- RevenueShare: install event recording, earnings aggregation
|
||||||
|
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jose import jwt
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.main import app
|
||||||
|
from app.marketplace.plugin_registry import PluginRegistry
|
||||||
|
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||||
|
from app.marketplace.revenue_share import RevenueShare
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_jwt(tier: str = "power", user_id: str | None = None) -> str:
|
||||||
|
uid = user_id or str(uuid.uuid4())
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"sub": uid,
|
||||||
|
"email": f"{uid[:8]}@example.com",
|
||||||
|
"tier": tier,
|
||||||
|
"exp": now + 3600,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth(tier: str = "power") -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {_make_jwt(tier)}"}
|
||||||
|
|
||||||
|
|
||||||
|
def _fresh_manifest(
|
||||||
|
plugin_id: str | None = None,
|
||||||
|
category: str = "productivity",
|
||||||
|
price_cents: int = 0,
|
||||||
|
permissions: list[str] | None = None,
|
||||||
|
) -> PluginManifest:
|
||||||
|
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
|
||||||
|
return PluginManifest(
|
||||||
|
id=pid,
|
||||||
|
name=f"Plugin {pid}",
|
||||||
|
description=f"Description for {pid}",
|
||||||
|
version="1.0.0",
|
||||||
|
author="test-author",
|
||||||
|
permissions=permissions or ["read:tasks"],
|
||||||
|
category=category,
|
||||||
|
price_cents=price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PluginRegistry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRegistry:
|
||||||
|
"""Each test uses a fresh PluginRegistry instance to avoid catalog pollution."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None:
|
||||||
|
result = await reg.list_plugins()
|
||||||
|
assert result.total == 3
|
||||||
|
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_approved_only(self, reg: PluginRegistry) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(manifest, "plugins/key.zip")
|
||||||
|
result = await reg.list_plugins()
|
||||||
|
ids = [p.id for p in result.plugins]
|
||||||
|
assert manifest.id not in ids # still pending
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_filter_by_category(self, reg: PluginRegistry) -> None:
|
||||||
|
result = await reg.list_plugins(category="communication")
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_filter_by_query(self, reg: PluginRegistry) -> None:
|
||||||
|
result = await reg.list_plugins(query="time")
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None:
|
||||||
|
await reg.record_install("plugin-slack-notify")
|
||||||
|
await reg.record_install("plugin-slack-notify")
|
||||||
|
result = await reg.list_plugins(sort="installs")
|
||||||
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_plugin_found(self, reg: PluginRegistry) -> None:
|
||||||
|
entry = await reg.get_plugin("plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["manifest"].id == "plugin-github-sync"
|
||||||
|
assert "install_count" in entry
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None:
|
||||||
|
entry = await reg.get_plugin("no-such-plugin")
|
||||||
|
assert entry is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_sets_pending(self, reg: PluginRegistry) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
plugin_id = await reg.submit_plugin(manifest, "key.zip")
|
||||||
|
assert plugin_id == manifest.id
|
||||||
|
assert reg._catalog[plugin_id]["status"] == "pending_review"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approve_makes_visible(self, reg: PluginRegistry) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(manifest, "key.zip")
|
||||||
|
await reg.approve_plugin(manifest.id)
|
||||||
|
result = await reg.list_plugins()
|
||||||
|
assert manifest.id in [p.id for p in result.plugins]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_stores_reason(self, reg: PluginRegistry) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(manifest, "key.zip")
|
||||||
|
await reg.reject_plugin(manifest.id, reason="Unsafe permissions")
|
||||||
|
assert reg._catalog[manifest.id]["status"] == "rejected"
|
||||||
|
assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions"
|
||||||
|
result = await reg.list_plugins()
|
||||||
|
assert manifest.id not in [p.id for p in result.plugins]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await reg.approve_plugin("ghost-plugin")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_increments_count(self, reg: PluginRegistry) -> None:
|
||||||
|
await reg.record_install("plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin("plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None:
|
||||||
|
await reg.record_install("plugin-github-sync")
|
||||||
|
await reg.record_install("plugin-github-sync")
|
||||||
|
await reg.record_uninstall("plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin("plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None:
|
||||||
|
await reg.record_uninstall("plugin-github-sync") # already 0
|
||||||
|
entry = await reg.get_plugin("plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReviewQueue
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewQueue:
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def queue(self, reg: PluginRegistry) -> ReviewQueue:
|
||||||
|
# Patch the 'registry' name as bound inside plugin_review.py
|
||||||
|
with patch("app.marketplace.plugin_review.registry", reg):
|
||||||
|
yield ReviewQueue()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_returns_submitted_plugins(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(manifest, "key.zip")
|
||||||
|
pending = await queue.get_pending()
|
||||||
|
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_review_approved(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(manifest, "key.zip")
|
||||||
|
await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good")
|
||||||
|
assert reg._catalog[manifest.id]["status"] == "approved"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_review_rejected(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(manifest, "key.zip")
|
||||||
|
await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions")
|
||||||
|
assert reg._catalog[manifest.id]["status"] == "rejected"
|
||||||
|
|
||||||
|
def test_validate_manifest_ok(self) -> None:
|
||||||
|
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||||
|
validate_manifest(manifest) # should not raise
|
||||||
|
|
||||||
|
def test_validate_manifest_unknown_permission(self) -> None:
|
||||||
|
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
|
||||||
|
with pytest.raises(ValueError, match="Unknown permission"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
def test_validate_manifest_invalid_id_format(self) -> None:
|
||||||
|
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
|
||||||
|
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
def test_validate_manifest_id_with_uppercase(self) -> None:
|
||||||
|
manifest = _fresh_manifest(plugin_id="UpperCase")
|
||||||
|
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RevenueShare
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevenueShare:
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def rs(self, reg: PluginRegistry) -> RevenueShare:
|
||||||
|
# Patch the 'registry' name as bound inside revenue_share.py
|
||||||
|
with patch("app.marketplace.revenue_share.registry", reg):
|
||||||
|
yield RevenueShare()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_free_plugin(
|
||||||
|
self, reg: PluginRegistry, rs: RevenueShare
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
||||||
|
assert len(rs._events) == 1
|
||||||
|
assert rs._events[0]["developer_share_cents"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_paid_plugin_no_stripe(
|
||||||
|
self, reg: PluginRegistry, rs: RevenueShare
|
||||||
|
) -> None:
|
||||||
|
# No STRIPE_SECRET_KEY configured in test env — should not crash
|
||||||
|
await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499)
|
||||||
|
assert len(rs._events) == 1
|
||||||
|
assert rs._events[0]["amount_cents"] == 499
|
||||||
|
assert rs._events[0]["developer_share_cents"] == int(499 * 0.70)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_increments_registry_count(
|
||||||
|
self, reg: PluginRegistry, rs: RevenueShare
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
||||||
|
entry = await reg.get_plugin("plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_earnings_empty(
|
||||||
|
self, reg: PluginRegistry, rs: RevenueShare
|
||||||
|
) -> None:
|
||||||
|
result = await rs.get_earnings("unknown-dev")
|
||||||
|
assert result["total_installs"] == 0
|
||||||
|
assert result["total_revenue_cents"] == 0
|
||||||
|
assert result["developer_share_cents"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_earnings_aggregates(
|
||||||
|
self, reg: PluginRegistry, rs: RevenueShare
|
||||||
|
) -> None:
|
||||||
|
# "Adiuva" is the author of the seeded plugins
|
||||||
|
await rs.record_install("plugin-slack-notify", "u1", amount_cents=499)
|
||||||
|
await rs.record_install("plugin-slack-notify", "u2", amount_cents=499)
|
||||||
|
result = await rs.get_earnings("Adiuva")
|
||||||
|
assert result["total_installs"] == 2
|
||||||
|
assert result["total_revenue_cents"] == 998
|
||||||
|
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Route integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRoutes:
|
||||||
|
def test_list_plugins_requires_power_tier(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=_auth("free"))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_list_plugins_pro_tier_blocked(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=_auth("pro"))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_list_plugins_power_tier_ok(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=_auth("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "plugins" in data
|
||||||
|
assert data["total"] >= 3
|
||||||
|
|
||||||
|
def test_list_plugins_team_tier_ok(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=_auth("team"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_get_plugin_found(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth())
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||||
|
assert "install_count" in data
|
||||||
|
|
||||||
|
def test_get_plugin_not_found(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth())
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_install_plugin_free(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
|
headers=_auth(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["ok"] is True
|
||||||
|
assert "download_url" in data
|
||||||
|
|
||||||
|
def test_install_plugin_not_found(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/ghost/install",
|
||||||
|
json={"plugin_id": "ghost"},
|
||||||
|
headers=_auth(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_uninstall_plugin_ok(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
headers=_auth(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["ok"] is True
|
||||||
|
|
||||||
|
def test_install_requires_power_tier(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
|
headers=_auth("free"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
Reference in New Issue
Block a user