Compare commits

..

4 Commits

Author SHA1 Message Date
5d485b3665 step 12 2026-03-03 12:39:32 +01:00
9787befd4a step 11 complete: billing service and tier manager
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 22:41:35 +01:00
8f7bc25611 step 10 complete: plugin marketplace with catalog, review workflow, and revenue split
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 22:32:44 +01:00
3e07fff958 step 9 complete: auth middleware, tier-aware rate limiter, and response sanitizer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 22:18:17 +01:00
27 changed files with 2920 additions and 342 deletions

View File

@@ -331,14 +331,14 @@ adiuva-api/
### Step 9 — Middleware
#### 9a — Auth middleware
- [ ] `app/api/middleware/auth.py`:
- [x] `app/api/middleware/auth.py`:
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
- Validates JWT signature, expiry, extracts `user_id` and `tier`
- Raises `401` on invalid/expired token
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
#### 9b — Rate limiter
- [ ] `app/api/middleware/rate_limit.py`:
- [x] `app/api/middleware/rate_limit.py`:
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
- Tier-based limits:
- Free: 20 req/min
@@ -348,7 +348,7 @@ adiuva-api/
- Custom 429 response with `Retry-After` header
#### 9c — Sanitizer
- [ ] `app/api/middleware/sanitizer.py`:
- [x] `app/api/middleware/sanitizer.py`:
- Response middleware that scans response bodies
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
- Pattern-based detection + exact match against known prompt fingerprints
@@ -356,33 +356,33 @@ adiuva-api/
- **Outcome:** Secure, rate-limited API with prompt IP protection.
### Step 10 — Plugin Marketplace
- [ ] `app/marketplace/plugin_registry.py`:
### Step 10 — Plugin Marketplace
- [x] `app/marketplace/plugin_registry.py`:
- `PluginRegistry`:
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
- `async get_plugin(plugin_id) -> PluginManifest | None`
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
- `async reject_plugin(plugin_id, reason: str) -> None`
- [ ] `app/marketplace/plugin_review.py`:
- [x] `app/marketplace/plugin_review.py`:
- `ReviewQueue`:
- `async get_pending() -> list[dict]`
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
- [ ] `app/marketplace/revenue_share.py`:
- [x] `app/marketplace/revenue_share.py`:
- `RevenueShare`:
- `async record_install(plugin_id, user_id, amount_cents) -> None`
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
- `async get_earnings(developer_id, period) -> dict`
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
### Step 11 — Billing & Tier management
- [ ] `app/billing/stripe_service.py`:
### Step 11 — Billing & Tier management
- [x] `app/billing/stripe_service.py`:
- `create_checkout_session(user_id, tier) -> str`
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
- `get_subscription(user_id) -> dict | None`
- `cancel_subscription(user_id) -> None`
- [ ] `app/billing/tier_manager.py`:
- [x] `app/billing/tier_manager.py`:
- `TierManager`:
- Feature matrix:
```python
@@ -433,6 +433,9 @@ adiuva-api/
- `check_feature(user_id, feature) -> bool`
- `get_rate_limit(tier) -> int`
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
### Step 12 — Database (auth/billing/marketplace only)

47
alembic.ini Normal file
View File

@@ -0,0 +1,47 @@
# Alembic configuration file.
# The async app uses postgresql+asyncpg:// at runtime.
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
[alembic]
script_location = alembic
prepend_sys_path = .
version_path_separator = os
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

93
alembic/env.py Normal file
View File

@@ -0,0 +1,93 @@
"""Alembic migration environment — async-compatible.
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
env var by replacing the driver prefix.
Run migrations with:
alembic upgrade head
"""
from __future__ import annotations
import asyncio
import os
import re
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.ext.asyncio import create_async_engine
# Alembic Config object (gives access to alembic.ini values).
config = context.config
# Set up Python logging from alembic.ini.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Import the Base so that Alembic can detect model changes for --autogenerate.
from app.models import Base # noqa: E402
target_metadata = Base.metadata
def _sync_url(async_url: str) -> str:
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
def _get_url() -> str:
db_url = os.environ.get("DATABASE_URL", "")
if not db_url:
# Fall back to settings if env var not set directly.
from app.config.settings import settings # noqa: PLC0415
db_url = settings.DATABASE_URL
return _sync_url(db_url)
def run_migrations_offline() -> None:
"""Emit SQL without a live DB connection."""
url = _get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection): # type: ignore[no-untyped-def]
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online_async() -> None:
"""Run migrations against a live DB using the async engine."""
async_url = os.environ.get("DATABASE_URL", "")
if not async_url:
from app.config.settings import settings # noqa: PLC0415
async_url = settings.DATABASE_URL
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
asyncio.run(run_migrations_online_async())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,202 @@
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
Revision ID: 001
Revises:
Create Date: 2026-03-02
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "001"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Enum types ────────────────────────────────────────────────────────
billing_tier = postgresql.ENUM(
"free", "pro", "power", "team", name="billing_tier", create_type=False
)
plugin_status = postgresql.ENUM(
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
)
review_decision = postgresql.ENUM(
"approved", "rejected", name="review_decision", create_type=False
)
for enum in (billing_tier, plugin_status, review_decision):
enum.create(op.get_bind(), checkfirst=True)
# ── users ─────────────────────────────────────────────────────────────
op.create_table(
"users",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("email", sa.String(255), nullable=False),
sa.Column("password_hash", sa.String(255), nullable=False),
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email"),
)
op.create_index("ix_users_email", "users", ["email"])
# ── refresh_tokens ────────────────────────────────────────────────────
op.create_table(
"refresh_tokens",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("token_hash", sa.String(64), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("token_hash"),
)
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
# ── subscriptions ─────────────────────────────────────────────────────
op.create_table(
"subscriptions",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("user_id"),
)
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions")
op.drop_table("refresh_tokens")
op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -1,46 +1,14 @@
"""Shared FastAPI dependencies.
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
Step 9 will layer rate-limiting and sanitization middleware on top of this.
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
(the canonical location per Step 9). This module re-exports them so that all
existing route imports (``from app.api.deps import get_current_user``) continue
to work without modification.
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
instead of reading it from the JWT payload.
"""
from __future__ import annotations
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from app.config.settings import settings
from app.schemas import BillingTier, UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
Raises ``HTTP 401`` on any invalid or expired token.
The tier embedded in the JWT is used for feature-gating until Step 12
adds a live DB lookup.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
tier: str = payload.get("tier", "free")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
__all__ = ["get_current_user", "oauth2_scheme"]

View File

@@ -0,0 +1,19 @@
"""API middleware package.
Exports the three middleware components introduced in Step 9:
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
- Sanitizer: ``SanitizerMiddleware``
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
from app.api.middleware.sanitizer import SanitizerMiddleware
__all__ = [
"get_current_user",
"oauth2_scheme",
"TierRateLimitMiddleware",
"limiter",
"SanitizerMiddleware",
]

View File

@@ -0,0 +1,65 @@
"""Auth middleware — JWT validation dependency.
``get_current_user`` is the FastAPI dependency used by all protected routes.
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
from the ``subscriptions`` table so that tier changes take effect immediately
without requiring token re-issue.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.db import get_session
from app.schemas import UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry only. The tier is fetched live
from the ``subscriptions`` table so that upgrades/downgrades take effect
immediately. Falls back to ``'free'`` when no subscription row exists.
Raises HTTP 401 on any invalid or expired token.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup — subscription row is the authoritative source.
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str = result.scalar_one_or_none() or "free"
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]

View File

@@ -0,0 +1,129 @@
"""Tier-aware rate limiting middleware.
Uses a per-user sliding-window counter (in-process, no Redis required).
The ``slowapi`` Limiter is also exported for optional route-level decoration.
Limits (requests per minute):
- free: 20
- pro: 60
- power: 120
- team: 200
Exempt paths bypass the limiter entirely:
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
from __future__ import annotations
import json
import time
from collections import defaultdict
from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.config.settings import settings
_TIER_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/billing/webhook",
"/api/v1/health",
}
)
def _get_user_id_from_jwt(request: Request) -> str:
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return get_remote_address(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
return payload.get("sub") or get_remote_address(request)
except JWTError:
return get_remote_address(request)
# Exported Limiter instance — available for optional route-level decoration.
limiter = Limiter(key_func=_get_user_id_from_jwt)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""Sliding-window rate limiter applied globally across all non-exempt routes.
Each authenticated user gets their own 60-second window sized by tier.
Unauthenticated requests pass through (the auth dependency will reject them
with 401 before the route handler runs).
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
# user_id → list of request timestamps (float, seconds since epoch)
self._window: dict[str, list[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
if request.url.path in _EXEMPT_PATHS:
return await call_next(request)
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return await call_next(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str = payload.get("sub") or get_remote_address(request)
tier: str = payload.get("tier", "free")
except JWTError:
return await call_next(request)
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
now = time.monotonic()
window_start = now - 60.0
# Slide the window: discard timestamps older than 60 seconds.
timestamps = [t for t in self._window[user_id] if t > window_start]
if len(timestamps) >= limit:
retry_after = max(1, int(60 - (now - min(timestamps))))
return Response(
content=json.dumps(
{
"detail": (
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
f"Retry in {retry_after}s."
)
}
),
status_code=429,
headers={
"Retry-After": str(retry_after),
"Content-Type": "application/json",
},
)
timestamps.append(now)
self._window[user_id] = timestamps
return await call_next(request)

View File

@@ -0,0 +1,139 @@
"""Response sanitizer middleware.
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:
- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
from __future__ import annotations
import json
import logging
import re
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Detection patterns — order matters: fingerprints checked first (exact),
# then compiled regexes.
# ---------------------------------------------------------------------------
_FINGERPRINTS: tuple[str, ...] = (
"You are an intent classifier",
"Respond with just the agent name",
"Summarize these agent results",
"Available agents:",
"route to:",
)
_PATTERNS: tuple[re.Pattern[str], ...] = (
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
re.compile(r"Available agents\s*:", re.IGNORECASE),
re.compile(r"\bintent classifier\b", re.IGNORECASE),
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
re.compile(r"route\s+to\s*:", re.IGNORECASE),
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
)
def _sanitize_text(text: str) -> tuple[str, bool]:
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
Returns ``(cleaned_text, was_changed)``.
"""
# Fingerprint check — if any exact phrase is present, redact the whole string.
for fp in _FINGERPRINTS:
if fp in text:
return "[REDACTED]", True
changed = False
for pattern in _PATTERNS:
new_text, n = pattern.subn("[REDACTED]", text)
if n:
text = new_text
changed = True
return text, changed
class SanitizerMiddleware(BaseHTTPMiddleware):
"""Strip prompt IP from /api/v1/chat JSON responses."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
response: Response = await call_next(request)
# Only process chat endpoint responses.
if not request.url.path.startswith("/api/v1/chat"):
return response
# Read body — collect streaming chunks.
body_bytes = b""
async for chunk in response.body_iterator:
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
try:
body = json.loads(body_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
if not isinstance(body, dict):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Walk top-level string fields and sanitise.
sanitised_fields: list[str] = []
for key, value in body.items():
if isinstance(value, str):
cleaned, changed = _sanitize_text(value)
if changed:
body[key] = cleaned
sanitised_fields.append(key)
if sanitised_fields:
logger.warning(
"Sanitizer redacted prompt fragments",
extra={
"path": request.url.path,
"fields": sanitised_fields,
},
)
new_body = json.dumps(body).encode("utf-8")
headers = dict(response.headers)
headers["content-length"] = str(len(new_body))
return Response(
content=new_body,
status_code=response.status_code,
headers=headers,
media_type="application/json",
)

View File

@@ -1,33 +1,36 @@
"""Auth routes: register, login, refresh, me.
Users and refresh tokens are kept in an in-memory dict until Step 12
migrates them to PostgreSQL.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from typing import Any
from datetime import datetime, timedelta, timezone
import bcrypt
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.config.settings import settings
from app.db import get_session
from app.models import RefreshToken, User
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
_users: dict[str, dict[str, Any]] = {} # email → user record
_refresh_tokens: dict[str, str] = {} # plain token → user_id
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (signed JWT, expires_at_ms)."""
now = int(time.time())
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
access_payload = {
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": access_exp,
"exp": exp,
"iat": now,
}
access_token = jwt.encode(
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
refresh_token = str(uuid.uuid4())
_refresh_tokens[refresh_token] = user_id
return AuthTokens(
access_token=access_token,
refresh_token=refresh_token,
expires_at=access_exp * 1000, # milliseconds for client
)
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return token, exp * 1000 # ms for client
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(body: _RegisterRequest) -> AuthTokens:
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
if body.email in _users:
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user_id = str(uuid.uuid4())
_users[body.email] = {
"id": user_id,
"email": body.email,
"password_hash": _hash_password(body.password),
"tier": "free",
}
return _make_tokens(user_id, body.email, "free")
user = User(
id=str(uuid.uuid4()),
email=body.email,
password_hash=_hash_password(body.password),
tier="free",
)
db.add(user)
await db.flush() # get user.id without committing
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(body: _LoginRequest) -> AuthTokens:
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
user = _users.get(body.email)
if not user or not _verify_password(body.password, user["password_hash"]):
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
return _make_tokens(user["id"], user["email"], user["tier"])
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(body: _RefreshRequest) -> AuthTokens:
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
user_id = _refresh_tokens.pop(body.refresh_token, None)
if user_id is None:
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
user = next((u for u in _users.values() if u["id"] == user_id), None)
# Rotate: delete old token, issue new one.
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
return _make_tokens(user["id"], user["email"], user["tier"])
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.get("/me", response_model=UserProfile)

View File

@@ -16,6 +16,7 @@ from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -27,32 +28,11 @@ _blob_store = BlobStore()
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
"free": 0,
"pro": 5,
"power": 25,
"team": -1, # unlimited
}
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail="Backup is not available on the free tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024**3
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
if used + size_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
@router.put("")
@@ -69,7 +49,7 @@ async def upload_backup(
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
_check_backup_quota(current_user.id, current_user.tier, len(blob))
_check_backup_quota(current_user.id, len(blob))
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum

View File

@@ -1,44 +1,25 @@
"""Billing routes: Stripe checkout, webhook, subscription management.
Subscription records are kept in-memory until Step 12 migrates them to
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
Business logic lives in ``app.billing.stripe_service.StripeService``.
The route layer handles HTTP concerns (request parsing, response shaping)
and delegates everything else to the service singleton.
"""
from __future__ import annotations
from typing import Any
import stripe as stripe_lib
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi import APIRouter, Depends, Header, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.config.settings import settings
from app.billing.stripe_service import stripe_service
from app.db import get_session
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
_TIER_PRICE_IDS: dict[str, str] = {
"pro": "price_pro_monthly", # replace with real Stripe price IDs
"power": "price_power_monthly",
"team": "price_team_monthly",
}
# ── Helpers ────────────────────────────────────────────────────────────
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Request bodies ─────────────────────────────────────────────────────
@@ -57,40 +38,15 @@ async def create_checkout(
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
"""
if body.tier == "free":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create a checkout session for the free tier",
)
if _stripe_configured():
price_id = _TIER_PRICE_IDS.get(body.tier)
if not price_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown tier: {body.tier}",
)
s = _stripe()
session = s.checkout.Session.create(
payment_method_types=["card"],
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
success_url=(
"https://app.adiuva.app/billing/success"
"?session_id={CHECKOUT_SESSION_ID}"
),
cancel_url="https://app.adiuva.app/billing/cancel",
metadata={"user_id": current_user.id, "tier": body.tier},
)
return {"checkout_url": session.url}
return {"checkout_url": "https://stripe.com/stub-checkout"}
url = stripe_service.create_checkout_session(current_user.id, body.tier)
return {"checkout_url": url}
@router.post("/webhook", response_model=dict)
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
@@ -98,57 +54,17 @@ async def stripe_webhook(
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
if not _stripe_configured():
return {"ok": True}
try:
s = _stripe()
event = s.Webhook.construct_event(
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
)
except stripe_lib.error.SignatureVerificationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid Stripe signature",
)
event_type: str = event["type"]
data: dict[str, Any] = event["data"]["object"]
if event_type == "checkout.session.completed":
user_id = data.get("metadata", {}).get("user_id")
tier = data.get("metadata", {}).get("tier", "free")
sub_id = data.get("subscription")
if user_id:
_subscriptions[user_id] = {
"tier": tier,
"stripe_subscription_id": sub_id,
"status": "active",
"current_period_end": None,
}
elif event_type == "customer.subscription.updated":
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
pass
elif event_type == "customer.subscription.deleted":
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
pass
elif event_type == "invoice.payment_failed":
# TODO(Step12): flag subscription as past_due, notify user
pass
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
@router.get("/subscription", response_model=dict)
async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = _subscriptions.get(current_user.id)
sub = await stripe_service.get_subscription(current_user.id, db)
if sub is None:
return {
"tier": current_user.tier,
@@ -159,26 +75,11 @@ async def get_subscription(
return sub
@router.delete("/subscription", response_model=dict)
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Cancel the active subscription."""
sub = _subscriptions.get(current_user.id)
if sub is None or not sub.get("stripe_subscription_id"):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active subscription found",
)
if _stripe_configured():
s = _stripe()
s.Subscription.cancel(sub["stripe_subscription_id"])
_subscriptions[current_user.id] = {
**sub,
"tier": "free",
"status": "canceled",
}
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}

View File

@@ -1,7 +1,8 @@
"""Plugins routes: browse and install plugins from the marketplace.
The catalog and installation records are kept in-memory as stubs.
Step 10 replaces these with PluginRegistry, RevenueShare, and the plugins DB table.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
in Step 10. Step 12 will swap those services' in-memory stores for
PostgreSQL persistence.
"""
from __future__ import annotations
@@ -12,49 +13,12 @@ from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── In-memory catalog (Step 10 replaces with PluginRegistry + DB) ─────
_plugin_catalog: list[PluginManifest] = [
PluginManifest(
id="plugin-github-sync",
name="GitHub Sync",
description="Sync tasks with GitHub Issues and pull requests.",
version="1.0.0",
author="Adiuva",
permissions=["read:tasks", "write:tasks"],
category="productivity",
price_cents=0,
),
PluginManifest(
id="plugin-slack-notify",
name="Slack Notifier",
description="Post task and checkpoint updates to Slack channels.",
version="1.2.0",
author="Adiuva",
permissions=["read:tasks", "read:checkpoints"],
category="communication",
price_cents=499,
),
PluginManifest(
id="plugin-time-tracker",
name="Time Tracker",
description="Track time spent on tasks with automatic reporting.",
version="0.9.1",
author="Third Party",
permissions=["read:tasks", "write:tasks"],
category="productivity",
price_cents=999,
),
]
# plugin_id → set of user_ids who have installed it
_installations: dict[str, set[str]] = {}
# ── Tier gate ─────────────────────────────────────────────────────────
@@ -67,43 +31,12 @@ def _require_plugin_tier(user: UserProfile) -> None:
)
# ── Filter + sort helpers ──────────────────────────────────────────────
def _apply_filters(
plugins: list[PluginManifest],
category: str | None,
q: str | None,
) -> list[PluginManifest]:
result = plugins
if category:
result = [p for p in result if p.category == category]
if q:
q_lower = q.lower()
result = [
p for p in result
if q_lower in p.name.lower() or q_lower in p.description.lower()
]
return result
def _apply_sort(
plugins: list[PluginManifest],
sort: str,
) -> list[PluginManifest]:
if sort == "installs":
return sorted(plugins, key=lambda p: len(_installations.get(p.id, set())), reverse=True)
if sort == "rating":
# Placeholder until Step 10 introduces avg_rating from DB
return sorted(plugins, key=lambda p: -p.price_cents)
return plugins # "newest" = catalog insertion order
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any] # Step 10 populates from plugin_reviews table
ratings: list[Any] # Step 12 populates from plugin_reviews table
# ── Routes ────────────────────────────────────────────────────────────
@@ -118,9 +51,7 @@ async def list_plugins(
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
filtered = _apply_filters(_plugin_catalog, category, q)
sorted_plugins = _apply_sort(filtered, sort)
return PluginListResponse(plugins=sorted_plugins, total=len(sorted_plugins), page=page)
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
@@ -130,13 +61,13 @@ async def get_plugin(
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
if plugin is None:
entry = await registry.get_plugin(plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
return _PluginDetail(
plugin=plugin,
install_count=len(_installations.get(plugin_id, set())),
ratings=[], # Step 10 populates from plugin_reviews table
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=[], # Step 12 populates from plugin_reviews table
)
@@ -146,20 +77,21 @@ async def install_plugin(
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect for paid plugins when configured.
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
if plugin is None:
entry = await registry.get_plugin(plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
if plugin.price_cents > 0 and settings.STRIPE_SECRET_KEY:
# TODO(Step10): stripe.PaymentIntent.create with destination charge (70/30 split)
pass
await revenue_share.record_install(
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
_installations.setdefault(plugin_id, set()).add(current_user.id)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@@ -170,5 +102,5 @@ async def uninstall_plugin(
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
_installations.get(plugin_id, set()).discard(current_user.id)
await registry.record_uninstall(plugin_id)
return {"ok": True}

View File

@@ -14,6 +14,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -25,14 +26,6 @@ _blob_store = BlobStore()
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
_records: dict[str, dict[str, Any]] = {}
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
"free": 0,
"pro": 5,
"power": 25,
"team": -1, # unlimited
}
# ── Local response schemas ─────────────────────────────────────────────
@@ -51,18 +44,10 @@ class _RecordMeta(BaseModel):
# ── Helpers ────────────────────────────────────────────────────────────
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
def _check_quota(user_id: str, additional_bytes: int) -> None:
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024**3
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
if used + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
@@ -83,7 +68,7 @@ async def create_record(
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
_check_quota(current_user.id, current_user.tier, len(body.blob))
_check_quota(current_user.id, len(body.blob))
record_id = str(uuid.uuid4())
now = int(time.time() * 1000)
@@ -159,7 +144,7 @@ async def update_record(
delta = len(body.blob) - record["size_bytes"]
if delta > 0:
_check_quota(current_user.id, current_user.tier, delta)
_check_quota(current_user.id, delta)
s3_key = await _blob_store.upload(
current_user.id, record["table"], record_id, body.blob, body.checksum

View File

@@ -0,0 +1,4 @@
from app.billing.stripe_service import stripe_service
from app.billing.tier_manager import tier_manager
__all__ = ["stripe_service", "tier_manager"]

View File

@@ -0,0 +1,256 @@
"""Stripe service: checkout sessions, webhook handling, subscription management.
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
configured, enabling local development without live credentials.
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
# Stripe price IDs per tier — replace with real IDs in production .env
TIER_PRICE_IDS: dict[str, str] = {
"pro": "price_pro_monthly",
"power": "price_power_monthly",
"team": "price_team_monthly",
}
class StripeService:
"""Wraps all Stripe interactions and owns subscription persistence."""
# ── Internal helpers ────────────────────────────────────────────────
def _configured(self) -> bool:
return bool(settings.STRIPE_SECRET_KEY)
def _client(self) -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Public API ──────────────────────────────────────────────────────
def create_checkout_session(
self,
user_id: str,
tier: str,
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel",
) -> str:
"""Create a Stripe checkout session and return the URL.
Returns a stub URL when Stripe is not configured.
Raises ``HTTP 400`` for the free tier or an unknown tier.
"""
if tier == "free":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create a checkout session for the free tier",
)
price_id = TIER_PRICE_IDS.get(tier)
if not price_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown tier: {tier}",
)
if not self._configured():
return "https://stripe.com/stub-checkout"
s = self._client()
session = s.checkout.Session.create(
payment_method_types=["card"],
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
success_url=success_url,
cancel_url=cancel_url,
metadata={"user_id": user_id, "tier": tier},
)
return session.url
async def handle_webhook(
self,
payload: bytes,
sig_header: str,
db: AsyncSession,
) -> None:
"""Process a Stripe webhook event.
Verifies the signature, then dispatches on event type.
Raises ``HTTP 400`` on signature mismatch.
No-ops when Stripe is not configured.
"""
if not self._configured():
return
try:
s = self._client()
event = s.Webhook.construct_event(
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
)
except stripe_lib.error.SignatureVerificationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid Stripe signature",
)
event_type: str = event["type"]
data: dict[str, Any] = event["data"]["object"]
if event_type == "checkout.session.completed":
user_id = data.get("metadata", {}).get("user_id")
tier = data.get("metadata", {}).get("tier", "free")
sub_id = data.get("subscription")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if user_id:
await self._upsert_subscription(
db, user_id, sub_id, tier, "active", period_end
)
elif event_type == "customer.subscription.updated":
sub_id = data.get("id")
new_status = data.get("status", "active")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status=new_status, current_period_end=period_end
)
elif event_type == "customer.subscription.deleted":
sub_id = data.get("id")
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, tier="free", status="canceled"
)
elif event_type == "invoice.payment_failed":
sub_id = data.get("subscription")
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status="past_due"
)
await db.commit()
async def get_subscription(
self, user_id: str, db: AsyncSession
) -> dict[str, Any] | None:
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
return None
return {
"tier": sub.tier,
"stripe_subscription_id": sub.stripe_subscription_id,
"status": sub.status,
"current_period_end": (
int(sub.current_period_end.timestamp() * 1000)
if sub.current_period_end
else None
),
}
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
"""Cancel the user's Stripe subscription and downgrade them to free.
Raises ``HTTP 404`` when no active subscription exists.
"""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None or not sub.stripe_subscription_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active subscription found",
)
if self._configured():
s = self._client()
s.Subscription.cancel(sub.stripe_subscription_id)
sub.tier = "free"
sub.status = "canceled"
await db.commit()
# ── Private DB helpers ───────────────────────────────────────────────
async def _upsert_subscription(
self,
db: AsyncSession,
user_id: str,
stripe_subscription_id: str | None,
tier: str,
sub_status: str,
current_period_end: datetime | None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
sub = Subscription(user_id=user_id)
db.add(sub)
sub.stripe_subscription_id = stripe_subscription_id
sub.tier = tier
sub.status = sub_status
sub.current_period_end = current_period_end
async def _update_subscription_by_stripe_id(
self,
db: AsyncSession,
stripe_subscription_id: str,
*,
tier: str | None = None,
status: str | None = None,
current_period_end: datetime | None = None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(
Subscription.stripe_subscription_id == stripe_subscription_id
)
)
sub = result.scalar_one_or_none()
if sub is None:
return
if tier is not None:
sub.tier = tier
if status is not None:
sub.status = status
if current_period_end is not None:
sub.current_period_end = current_period_end
# Module-level singleton shared across the app.
stripe_service = StripeService()

189
app/billing/tier_manager.py Normal file
View File

@@ -0,0 +1,189 @@
"""Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
Quota-enforcement helpers take ``tier`` directly — the caller already has it
from ``current_user.tier`` (provided by ``get_current_user``).
"""
from __future__ import annotations
from typing import Any
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import BillingTier
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
FEATURES: dict[str, dict[str, Any]] = {
"free": {
"agents": 3,
"batch_active": 2,
"cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": False,
},
"team": {
"agents": -1,
"batch_active": -1,
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": True,
},
}
# Requests-per-minute limit per tier.
RATE_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
class TierManager:
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
# ── Tier lookup ─────────────────────────────────────────────────────
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'free'`` when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────
def check_feature(self, tier: BillingTier, feature: str) -> bool:
"""Return ``True`` if ``tier`` has ``feature`` enabled.
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
"""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if value is None:
return False
if isinstance(value, bool):
return value
return value != 0
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
if not self.check_feature(tier, feature):
detail = (
f"Feature '{feature}' requires {tier_name} tier or above."
if tier_name
else f"Feature '{feature}' is not available on your current tier."
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int:
"""Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app.
tier_manager = TierManager()

40
app/db.py Normal file
View File

@@ -0,0 +1,40 @@
"""Database engine, session factory, and base model.
All app code uses the async SQLAlchemy API. Alembic migrations use the
synchronous psycopg2 URL for the CLI (see alembic/env.py).
Usage in routes:
from app.db import get_session
from sqlalchemy.ext.asyncio import AsyncSession
async def my_route(db: AsyncSession = Depends(get_session)):
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config.settings import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
echo=settings.ENV == "dev",
)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
"""Shared declarative base for all ORM models."""
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session per request."""
async with async_session() as session:
yield session

View File

@@ -3,6 +3,8 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
@@ -14,7 +16,9 @@ async def lifespan(app: FastAPI):
yield
# Shutdown: nothing to clean up for now
# Shutdown: dispose SQLAlchemy connection pool
from app.db import engine
await engine.dispose()
def create_app() -> FastAPI:
@@ -33,6 +37,11 @@ def create_app() -> FastAPI:
allow_methods=["*"],
allow_headers=["*"],
)
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
# Request flow: TierRateLimit → Sanitizer → CORS → Router
# Response flow: Router → CORS → Sanitizer → TierRateLimit
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors

View File

@@ -0,0 +1,7 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -0,0 +1,211 @@
"""Plugin catalog registry.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. Storage is in-memory until Step 12 migrates to
the ``plugins`` PostgreSQL table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import copy
import time
import uuid
from typing import Any, Literal
from app.schemas import PluginListResponse, PluginManifest
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
_SEED_PLUGINS: list[dict[str, Any]] = [
{
"manifest": PluginManifest(
id="plugin-github-sync",
name="GitHub Sync",
description="Sync tasks with GitHub Issues and pull requests.",
version="1.0.0",
author="Adiuva",
permissions=["read:tasks", "write:tasks"],
category="productivity",
price_cents=0,
),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
},
{
"manifest": PluginManifest(
id="plugin-slack-notify",
name="Slack Notifier",
description="Post task and checkpoint updates to Slack channels.",
version="1.2.0",
author="Adiuva",
permissions=["read:tasks", "read:checkpoints"],
category="communication",
price_cents=499,
),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
},
{
"manifest": PluginManifest(
id="plugin-time-tracker",
name="Time Tracker",
description="Track time spent on tasks with automatic reporting.",
version="0.9.1",
author="Third Party",
permissions=["read:tasks", "write:tasks"],
category="productivity",
price_cents=999,
),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
},
]
_PAGE_SIZE = 20
class PluginRegistry:
"""In-process plugin catalog.
All mutating methods are ``async`` to make the future DB swap transparent
to callers.
"""
def __init__(self) -> None:
# plugin_id → entry dict (deep-copied so each instance is independent)
self._catalog: dict[str, dict[str, Any]] = {
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
}
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
if category:
entries = [e for e in entries if e["manifest"].category == category]
if query:
q_lower = query.lower()
entries = [
e
for e in entries
if q_lower in e["manifest"].name.lower()
or q_lower in e["manifest"].description.lower()
]
if sort == "installs":
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
elif sort == "rating":
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
total = len(entries)
start = (page - 1) * _PAGE_SIZE
page_entries = entries[start : start + _PAGE_SIZE]
return PluginListResponse(
plugins=[e["manifest"] for e in page_entries],
total=total,
page=page,
)
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
entry = self._catalog.get(plugin_id)
if entry is None:
return None
return {
"manifest": entry["manifest"],
"status": entry["status"],
"install_count": entry["install_count"],
"avg_rating": entry["avg_rating"],
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id or str(uuid.uuid4())
self._catalog[plugin_id] = {
"manifest": manifest,
"status": "pending_review",
"s3_package_key": package_s3_key,
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
}
return plugin_id
async def approve_plugin(self, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
if plugin_id not in self._catalog:
raise KeyError(f"Plugin not found: {plugin_id}")
self._catalog[plugin_id]["status"] = "approved"
self._catalog[plugin_id]["rejection_reason"] = None
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
if plugin_id not in self._catalog:
raise KeyError(f"Plugin not found: {plugin_id}")
self._catalog[plugin_id]["status"] = "rejected"
self._catalog[plugin_id]["rejection_reason"] = reason
async def record_install(self, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
if plugin_id in self._catalog:
self._catalog[plugin_id]["install_count"] += 1
async def record_uninstall(self, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
if plugin_id in self._catalog:
current = self._catalog[plugin_id]["install_count"]
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
# ── Internal helpers used by ReviewQueue ─────────────────────────
def _get_pending_entries(self) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review' (synchronous helper)."""
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -0,0 +1,127 @@
"""Plugin review workflow.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
import time
from typing import Any, Literal
from app.marketplace.plugin_registry import registry
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:checkpoints",
"write:checkpoints",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton so
there is a single source of truth for plugin state.
"""
def __init__(self) -> None:
# Completed reviews — Step 12 stores in plugin_reviews table
self._reviews: list[dict[str, Any]] = []
async def get_pending(self) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = registry._get_pending_entries()
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(plugin_id)
else:
await registry.reject_plugin(plugin_id, reason=notes)
self._reviews.append(
{
"plugin_id": plugin_id,
"reviewer_id": reviewer_id,
"decision": decision,
"notes": notes,
"reviewed_at": int(time.time()),
}
)
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -0,0 +1,205 @@
"""Revenue share tracking and Stripe Connect payouts.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Storage is
in-memory until Step 12 migrates to the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
import time
from typing import Any
import stripe as stripe_lib
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
def __init__(self) -> None:
# Step 12 replaces with revenue_events DB table
self._events: list[dict[str, Any]] = []
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
plugin_entry = registry._catalog.get(plugin_id)
developer_stripe_account: str | None = None
if plugin_entry:
# Step 12: look up developer's Stripe account from DB
# For now, the author field is used as a placeholder key.
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
self._events.append(
{
"plugin_id": plugin_id,
"user_id": user_id,
"amount_cents": amount_cents,
"developer_share_cents": developer_share_cents,
"stripe_transfer_id": stripe_transfer_id,
"paid_at": None,
"created_at": int(time.time()),
}
)
await registry.record_install(plugin_id)
async def get_earnings(
self,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer
developer_plugin_ids: set[str] = {
pid
for pid, entry in registry._catalog.items()
if entry["manifest"].author == developer_id
}
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
if period:
# Filter by YYYY-MM prefix of the created_at timestamp
events = [
e
for e in events
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
]
return {
"developer_id": developer_id,
"period": period,
"total_installs": len(events),
"total_revenue_cents": sum(e["amount_cents"] for e in events),
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
}
async def payout_developer(self, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
unpaid = [
e
for e in self._events
if e["plugin_id"] == plugin_id
and e["paid_at"] is None
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
]
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_entry = registry._catalog.get(plugin_id)
developer_stripe_account: str | None = None # Step 12: fetch from DB
if plugin_entry and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = int(time.time())
for event in unpaid:
event["paid_at"] = paid_ts
# Module-level singleton
revenue_share = RevenueShare()

269
app/models.py Normal file
View File

@@ -0,0 +1,269 @@
"""SQLAlchemy ORM models for all persistent tables.
Only auth, billing, storage metadata, and marketplace data live here.
User content (notes, tasks, etc.) is NEVER persisted server-side —
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Table inventory:
users — account credentials + tier
refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records
storage_records — S3 blob metadata (no plaintext)
backup_metadata — encrypted backup manifests
plugins — marketplace plugin catalog
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db import Base
# ── Helpers ──────────────────────────────────────────────────────────────
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.now(timezone.utc)
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
# ── Models ────────────────────────────────────────────────────────────────
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan"
)
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="refresh_tokens")
class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, unique=True, index=True
)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")

304
tests/test_middleware.py Normal file
View File

@@ -0,0 +1,304 @@
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
Rate limit: use unique user UUIDs per test so windows are independent;
the free-tier threshold (20 req/min) is exercised directly.
Sanitizer: the orchestrator is mocked to inject controlled prompt
fragments, and the chat endpoint response body is inspected.
"""
from __future__ import annotations
import time
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from jose import jwt
from app.config.settings import settings
from app.main import app
from app.schemas import ChatResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_CHAT_BODY = {
"message": "hello",
"context": {
"user_profile": {},
"relevant_documents": [],
"recent_tasks": [],
"conversation_history": [],
},
"execution_mode": "direct",
}
def _make_jwt(
*,
user_id: str | None = None,
email: str = "test@example.com",
tier: str = "free",
exp_offset: int = 3600,
secret: str | None = None,
include_sub: bool = True,
) -> str:
"""Mint a test JWT signed with the configured (or custom) secret."""
uid = user_id or str(uuid.uuid4())
now = int(time.time())
payload: dict = {
"email": email,
"tier": tier,
"exp": now + exp_offset,
"iat": now,
}
if include_sub:
payload["sub"] = uid
key = secret or settings.JWT_SECRET
return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
def _auth_header(token: str) -> dict[str, str]:
return {"Authorization": f"Bearer {token}"}
# ---------------------------------------------------------------------------
# Auth middleware
# ---------------------------------------------------------------------------
class TestAuthMiddleware:
"""Tests exercised via GET /api/v1/auth/me."""
def test_valid_token_returns_profile(self) -> None:
uid = str(uuid.uuid4())
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
data = resp.json()
assert data["id"] == uid
assert data["email"] == "alice@example.com"
assert data["tier"] == "pro"
def test_missing_token_returns_401(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
def test_expired_token_returns_401(self) -> None:
token = _make_jwt(exp_offset=-1) # already expired
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 401
def test_wrong_signature_returns_401(self) -> None:
token = _make_jwt(secret="totally-wrong-secret")
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 401
def test_missing_sub_claim_returns_401(self) -> None:
token = _make_jwt(include_sub=False)
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 401
def test_malformed_token_returns_401(self) -> None:
with TestClient(app) as client:
resp = client.get(
"/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"}
)
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Rate limiter middleware
# ---------------------------------------------------------------------------
class TestRateLimitMiddleware:
"""Each test uses a fresh unique user_id so windows never collide."""
def _unique_token(self, tier: str = "free") -> str:
return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)
def test_free_tier_allows_up_to_20_requests(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
def test_free_tier_blocks_21st_request(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token))
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
def test_429_includes_retry_after_header(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token))
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
assert "retry-after" in resp.headers
retry_after = int(resp.headers["retry-after"])
assert retry_after >= 1
def test_429_response_has_detail_field(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token))
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
assert "detail" in resp.json()
def test_pro_tier_allows_60_requests(self) -> None:
token = self._unique_token("pro")
with TestClient(app) as client:
# Sample: first 60 succeed, 61st is blocked.
for _ in range(60):
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
def test_independent_users_have_separate_windows(self) -> None:
token_a = self._unique_token("free")
token_b = self._unique_token("free")
with TestClient(app) as client:
# Exhaust user A's quota.
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token_a))
assert (
client.get(
"/api/v1/auth/me", headers=_auth_header(token_a)
).status_code
== 429
)
# User B's quota is untouched.
resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b))
assert resp_b.status_code == 200
def test_exempt_path_register_never_rate_limited(self) -> None:
"""POST /auth/register is exempt — 25 calls should never return 429."""
with TestClient(app) as client:
for i in range(25):
resp = client.post(
"/api/v1/auth/register",
json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
)
# 201 on first, 409 on email collision — but never 429.
assert resp.status_code != 429
def test_exempt_path_login_never_rate_limited(self) -> None:
"""POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
with TestClient(app) as client:
for _ in range(25):
resp = client.post(
"/api/v1/auth/login",
json={"email": "nosuchuser@example.com", "password": "wrong"},
)
assert resp.status_code != 429
def test_exempt_path_health_never_rate_limited(self) -> None:
with TestClient(app) as client:
for _ in range(25):
resp = client.get("/api/v1/health")
assert resp.status_code == 200
# ---------------------------------------------------------------------------
# Sanitizer middleware
# ---------------------------------------------------------------------------
class TestSanitizerMiddleware:
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
_CHAT_PATH = "/api/v1/chat"
def _token(self) -> str:
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
def _post_chat(self, client: TestClient, response_text: str) -> dict:
mock_response = ChatResponse(response=response_text, actions=[])
with patch(
"app.api.routes.chat.orchestrate",
new_callable=AsyncMock,
return_value=mock_response,
):
resp = client.post(
self._CHAT_PATH,
json=_CHAT_BODY,
headers=_auth_header(self._token()),
)
assert resp.status_code == 200
return resp.json()
def test_clean_response_passes_through_unchanged(self) -> None:
with TestClient(app) as client:
data = self._post_chat(client, "Sure, I created the task for you.")
assert data["response"] == "Sure, I created the task for you."
def test_strips_system_prompt_opener(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "You are an intent classifier. Route to task_agent."
)
assert "You are" not in data["response"]
assert "[REDACTED]" in data["response"]
def test_strips_known_fingerprint(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "Respond with just the agent name and nothing else."
)
assert data["response"] == "[REDACTED]"
def test_strips_tool_schema_fragment(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, 'Here is the schema: {"type": "function", "name": "foo"}'
)
assert '"type": "function"' not in data["response"]
def test_strips_reasoning_tag(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "<thinking>I should route this to calendar_agent</thinking>Done."
)
assert "<thinking>" not in data["response"]
assert "[REDACTED]" in data["response"]
def test_strips_available_agents_fragment(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "Available agents: task_agent, calendar_agent"
)
assert "[REDACTED]" in data["response"]
def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
"""GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
token = self._token()
with TestClient(app) as client:
resp = client.get(
"/api/v1/plans/playbook",
headers=_auth_header(token),
)
# The sanitizer should not interfere — just check it returns something
# (200 or whatever the route returns; we only care it's not broken).
assert resp.status_code in (200, 401, 403, 404)
def test_sanitizer_preserves_empty_response(self) -> None:
with TestClient(app) as client:
data = self._post_chat(client, "")
assert data["response"] == ""

387
tests/test_plugins.py Normal file
View File

@@ -0,0 +1,387 @@
"""Tests for Step 10: Plugin Marketplace.
Covers:
- PluginRegistry: catalog management, filtering, sorting, install counts
- ReviewQueue: pending queue, review decisions, manifest security checklist
- RevenueShare: install event recording, earnings aggregation
- Route integration: tier gate, list/get/install/uninstall via TestClient
"""
from __future__ import annotations
import time
import uuid
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from jose import jwt
from unittest.mock import patch
from app.config.settings import settings
from app.main import app
from app.marketplace.plugin_registry import PluginRegistry
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
from app.marketplace.revenue_share import RevenueShare
from app.schemas import PluginManifest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_jwt(tier: str = "power", user_id: str | None = None) -> str:
uid = user_id or str(uuid.uuid4())
now = int(time.time())
payload = {
"sub": uid,
"email": f"{uid[:8]}@example.com",
"tier": tier,
"exp": now + 3600,
"iat": now,
}
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
def _auth(tier: str = "power") -> dict[str, str]:
return {"Authorization": f"Bearer {_make_jwt(tier)}"}
def _fresh_manifest(
plugin_id: str | None = None,
category: str = "productivity",
price_cents: int = 0,
permissions: list[str] | None = None,
) -> PluginManifest:
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
return PluginManifest(
id=pid,
name=f"Plugin {pid}",
description=f"Description for {pid}",
version="1.0.0",
author="test-author",
permissions=permissions or ["read:tasks"],
category=category,
price_cents=price_cents,
)
# ---------------------------------------------------------------------------
# PluginRegistry
# ---------------------------------------------------------------------------
class TestPluginRegistry:
"""Each test uses a fresh PluginRegistry instance to avoid catalog pollution."""
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.mark.asyncio
async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None:
result = await reg.list_plugins()
assert result.total == 3
assert all(p.id.startswith("plugin-") for p in result.plugins)
@pytest.mark.asyncio
async def test_list_approved_only(self, reg: PluginRegistry) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "plugins/key.zip")
result = await reg.list_plugins()
ids = [p.id for p in result.plugins]
assert manifest.id not in ids # still pending
@pytest.mark.asyncio
async def test_list_filter_by_category(self, reg: PluginRegistry) -> None:
result = await reg.list_plugins(category="communication")
assert result.total == 1
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_list_filter_by_query(self, reg: PluginRegistry) -> None:
result = await reg.list_plugins(query="time")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"
@pytest.mark.asyncio
async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None:
await reg.record_install("plugin-slack-notify")
await reg.record_install("plugin-slack-notify")
result = await reg.list_plugins(sort="installs")
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_get_plugin_found(self, reg: PluginRegistry) -> None:
entry = await reg.get_plugin("plugin-github-sync")
assert entry is not None
assert entry["manifest"].id == "plugin-github-sync"
assert "install_count" in entry
@pytest.mark.asyncio
async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None:
entry = await reg.get_plugin("no-such-plugin")
assert entry is None
@pytest.mark.asyncio
async def test_submit_sets_pending(self, reg: PluginRegistry) -> None:
manifest = _fresh_manifest()
plugin_id = await reg.submit_plugin(manifest, "key.zip")
assert plugin_id == manifest.id
assert reg._catalog[plugin_id]["status"] == "pending_review"
@pytest.mark.asyncio
async def test_approve_makes_visible(self, reg: PluginRegistry) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await reg.approve_plugin(manifest.id)
result = await reg.list_plugins()
assert manifest.id in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_reject_stores_reason(self, reg: PluginRegistry) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await reg.reject_plugin(manifest.id, reason="Unsafe permissions")
assert reg._catalog[manifest.id]["status"] == "rejected"
assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions"
result = await reg.list_plugins()
assert manifest.id not in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None:
with pytest.raises(KeyError):
await reg.approve_plugin("ghost-plugin")
@pytest.mark.asyncio
async def test_record_install_increments_count(self, reg: PluginRegistry) -> None:
await reg.record_install("plugin-github-sync")
entry = await reg.get_plugin("plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None:
await reg.record_install("plugin-github-sync")
await reg.record_install("plugin-github-sync")
await reg.record_uninstall("plugin-github-sync")
entry = await reg.get_plugin("plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None:
await reg.record_uninstall("plugin-github-sync") # already 0
entry = await reg.get_plugin("plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 0
# ---------------------------------------------------------------------------
# ReviewQueue
# ---------------------------------------------------------------------------
class TestReviewQueue:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def queue(self, reg: PluginRegistry) -> ReviewQueue:
# Patch the 'registry' name as bound inside plugin_review.py
with patch("app.marketplace.plugin_review.registry", reg):
yield ReviewQueue()
@pytest.mark.asyncio
async def test_get_pending_returns_submitted_plugins(
self, reg: PluginRegistry, queue: ReviewQueue
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
pending = await queue.get_pending()
assert any(p["plugin_id"] == manifest.id for p in pending)
@pytest.mark.asyncio
async def test_submit_review_approved(
self, reg: PluginRegistry, queue: ReviewQueue
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good")
assert reg._catalog[manifest.id]["status"] == "approved"
@pytest.mark.asyncio
async def test_submit_review_rejected(
self, reg: PluginRegistry, queue: ReviewQueue
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions")
assert reg._catalog[manifest.id]["status"] == "rejected"
def test_validate_manifest_ok(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
validate_manifest(manifest) # should not raise
def test_validate_manifest_unknown_permission(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
with pytest.raises(ValueError, match="Unknown permission"):
validate_manifest(manifest)
def test_validate_manifest_invalid_id_format(self) -> None:
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
def test_validate_manifest_id_with_uppercase(self) -> None:
manifest = _fresh_manifest(plugin_id="UpperCase")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
# ---------------------------------------------------------------------------
# RevenueShare
# ---------------------------------------------------------------------------
class TestRevenueShare:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def rs(self, reg: PluginRegistry) -> RevenueShare:
# Patch the 'registry' name as bound inside revenue_share.py
with patch("app.marketplace.revenue_share.registry", reg):
yield RevenueShare()
@pytest.mark.asyncio
async def test_record_install_free_plugin(
self, reg: PluginRegistry, rs: RevenueShare
) -> None:
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
assert len(rs._events) == 1
assert rs._events[0]["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_record_install_paid_plugin_no_stripe(
self, reg: PluginRegistry, rs: RevenueShare
) -> None:
# No STRIPE_SECRET_KEY configured in test env — should not crash
await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499)
assert len(rs._events) == 1
assert rs._events[0]["amount_cents"] == 499
assert rs._events[0]["developer_share_cents"] == int(499 * 0.70)
@pytest.mark.asyncio
async def test_record_install_increments_registry_count(
self, reg: PluginRegistry, rs: RevenueShare
) -> None:
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
entry = await reg.get_plugin("plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_get_earnings_empty(
self, reg: PluginRegistry, rs: RevenueShare
) -> None:
result = await rs.get_earnings("unknown-dev")
assert result["total_installs"] == 0
assert result["total_revenue_cents"] == 0
assert result["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_get_earnings_aggregates(
self, reg: PluginRegistry, rs: RevenueShare
) -> None:
# "Adiuva" is the author of the seeded plugins
await rs.record_install("plugin-slack-notify", "u1", amount_cents=499)
await rs.record_install("plugin-slack-notify", "u2", amount_cents=499)
result = await rs.get_earnings("Adiuva")
assert result["total_installs"] == 2
assert result["total_revenue_cents"] == 998
assert result["developer_share_cents"] == int(499 * 0.70) * 2
# ---------------------------------------------------------------------------
# Route integration tests
# ---------------------------------------------------------------------------
class TestPluginRoutes:
def test_list_plugins_requires_power_tier(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("free"))
assert resp.status_code == 403
def test_list_plugins_pro_tier_blocked(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("pro"))
assert resp.status_code == 403
def test_list_plugins_power_tier_ok(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("power"))
assert resp.status_code == 200
data = resp.json()
assert "plugins" in data
assert data["total"] >= 3
def test_list_plugins_team_tier_ok(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("team"))
assert resp.status_code == 200
def test_get_plugin_found(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth())
assert resp.status_code == 200
data = resp.json()
assert data["plugin"]["id"] == "plugin-github-sync"
assert "install_count" in data
def test_get_plugin_not_found(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth())
assert resp.status_code == 404
def test_install_plugin_free(self) -> None:
with TestClient(app) as client:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=_auth(),
)
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert "download_url" in data
def test_install_plugin_not_found(self) -> None:
with TestClient(app) as client:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=_auth(),
)
assert resp.status_code == 404
def test_uninstall_plugin_ok(self) -> None:
with TestClient(app) as client:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=_auth(),
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_install_requires_power_tier(self) -> None:
with TestClient(app) as client:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=_auth("free"),
)
assert resp.status_code == 403