Step 12 - completed

This commit is contained in:
2026-03-03 14:53:34 +01:00
parent 5d485b3665
commit d0b303e745
13 changed files with 950 additions and 487 deletions

View File

@@ -439,7 +439,7 @@ adiuva-api/
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
### Step 12 — Database (auth/billing/marketplace only)
- [ ] PostgreSQL schema via Alembic:
- [x] PostgreSQL schema via Alembic:
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
@@ -449,8 +449,8 @@ adiuva-api/
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
- [ ] Initial Alembic migration
- [ ] SQLAlchemy models in `app/models.py`
- [x] Initial Alembic migration
- [x] SQLAlchemy models in `app/models.py`
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
### Step 13 — Testing & deployment

View File

@@ -0,0 +1,92 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and checkpoint updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:checkpoints"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

View File

@@ -1,7 +1,7 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is kept in an
in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table).
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
@@ -9,14 +9,17 @@ treating "history" as a ``{backup_id}`` path parameter.
from __future__ import annotations
import time
import uuid
from email.utils import parsedate_to_datetime
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -25,14 +28,25 @@ router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
@@ -42,6 +56,7 @@ async def upload_backup(
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
@@ -49,24 +64,23 @@ async def upload_backup(
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
_check_backup_quota(current_user.id, len(blob))
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
backup_record: dict[str, Any] = {
"id": str(x_backup_timestamp),
"s3_key": s3_key,
"version": x_backup_version,
"timestamp": x_backup_timestamp,
"checksum": x_backup_checksum,
"size_bytes": len(blob),
}
user_backups = _backups.setdefault(current_user.id, [])
user_backups.append(backup_record)
user_backups.sort(key=lambda b: b["timestamp"], reverse=True)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@@ -74,16 +88,23 @@ async def upload_backup(
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=b["version"],
timestamp=b["timestamp"],
checksum=b["checksum"],
chunk_count=1, # single-chunk uploads for now — TODO(Step12): track real count
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for b in _backups.get(current_user.id, [])
for r in rows
]
@@ -91,32 +112,37 @@ async def backup_history(
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
user_backups = _backups.get(current_user.id, [])
if not user_backups:
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
latest = user_backups[0]
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest["timestamp"] <= ims_ms:
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest["s3_key"])
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest["version"]),
"X-Backup-Timestamp": str(latest["timestamp"]),
"X-Checksum": latest["checksum"],
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@@ -125,14 +151,21 @@ async def download_backup(
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
user_backups = _backups.get(current_user.id, [])
target = next((b for b in user_backups if b["id"] == backup_id), None)
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target["s3_key"])
_backups[current_user.id] = [b for b in user_backups if b["id"] != backup_id]
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -1,8 +1,7 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
in Step 10. Step 12 will swap those services' in-memory stores for
PostgreSQL persistence.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
@@ -11,10 +10,14 @@ from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
@@ -36,7 +39,7 @@ def _require_plugin_tier(user: UserProfile) -> None:
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any] # Step 12 populates from plugin_reviews table
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@@ -48,26 +51,44 @@ async def list_plugins(
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(plugin_id)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=[], # Step 12 populates from plugin_reviews table
ratings=ratings,
)
@@ -76,17 +97,27 @@ async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(plugin_id)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
@@ -100,7 +131,18 @@ async def install_plugin(
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
await registry.record_uninstall(plugin_id)
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

View File

@@ -1,20 +1,23 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import time
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
@@ -23,9 +26,6 @@ router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
_records: dict[str, dict[str, Any]] = {}
# ── Local response schemas ─────────────────────────────────────────────
@@ -44,17 +44,34 @@ class _RecordMeta(BaseModel):
# ── Helpers ────────────────────────────────────────────────────────────
def _check_quota(user_id: str, additional_bytes: int) -> None:
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
"""Look up a record and verify ownership. Always returns 404 on mismatch
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
record = _records.get(record_id)
if record is None or record["user_id"] != user_id:
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
@@ -65,30 +82,32 @@ def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
_check_quota(current_user.id, len(body.blob))
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
now = int(time.time() * 1000)
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
_records[record_id] = {
"id": record_id,
"user_id": current_user.id,
"table": body.table,
"s3_key": s3_key,
"checksum": body.checksum,
"size_bytes": len(body.blob),
"created_at": now,
"updated_at": now,
}
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
return _CreateResponse(id=record_id, created_at=now)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
@@ -97,23 +116,26 @@ async def list_records(
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
all_records = [
r for r in _records.values()
if r["user_id"] == current_user.id and (table is None or r["table"] == table)
]
start = (page - 1) * limit
page_records = all_records[start : start + limit]
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r["id"],
table=r["table"],
checksum=r["checksum"],
created_at=r["created_at"],
updated_at=r["updated_at"],
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in page_records
for r in rows
]
@@ -121,14 +143,15 @@ async def list_records(
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = _get_record_for_user(record_id, current_user.id)
blob = await _blob_store.download(current_user.id, record["s3_key"])
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record["checksum"]},
headers={"X-Checksum": record.checksum},
)
@@ -137,23 +160,24 @@ async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = _get_record_for_user(record_id, current_user.id)
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record["size_bytes"]
delta = len(body.blob) - record.size_bytes
if delta > 0:
_check_quota(current_user.id, delta)
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record["table"], record_id, body.blob, body.checksum
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record["s3_key"] = s3_key
record["checksum"] = body.checksum
record["size_bytes"] = len(body.blob)
record["updated_at"] = int(time.time() * 1000)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@@ -162,9 +186,11 @@ async def update_record(
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = _get_record_for_user(record_id, current_user.id)
await _blob_store.delete(current_user.id, record["s3_key"])
del _records[record_id]
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

View File

@@ -1,8 +1,7 @@
"""Plugin catalog registry.
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. Storage is in-memory until Step 12 migrates to
the ``plugins`` PostgreSQL table.
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
@@ -11,144 +10,103 @@ Module-level singleton::
from __future__ import annotations
import copy
import time
import uuid
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
_SEED_PLUGINS: list[dict[str, Any]] = [
{
"manifest": PluginManifest(
id="plugin-github-sync",
name="GitHub Sync",
description="Sync tasks with GitHub Issues and pull requests.",
version="1.0.0",
author="Adiuva",
permissions=["read:tasks", "write:tasks"],
category="productivity",
price_cents=0,
),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
},
{
"manifest": PluginManifest(
id="plugin-slack-notify",
name="Slack Notifier",
description="Post task and checkpoint updates to Slack channels.",
version="1.2.0",
author="Adiuva",
permissions=["read:tasks", "read:checkpoints"],
category="communication",
price_cents=499,
),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
},
{
"manifest": PluginManifest(
id="plugin-time-tracker",
name="Time Tracker",
description="Track time spent on tasks with automatic reporting.",
version="0.9.1",
author="Third Party",
permissions=["read:tasks", "write:tasks"],
category="productivity",
price_cents=999,
),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
},
]
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""In-process plugin catalog.
"""PostgreSQL-backed plugin catalog.
All mutating methods are ``async`` to make the future DB swap transparent
to callers.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
def __init__(self) -> None:
# plugin_id → entry dict (deep-copied so each instance is independent)
self._catalog: dict[str, dict[str, Any]] = {
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
}
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
base = select(Plugin).where(Plugin.status == "approved")
if category:
entries = [e for e in entries if e["manifest"].category == category]
base = base.where(Plugin.category == category)
if query:
q_lower = query.lower()
entries = [
e
for e in entries
if q_lower in e["manifest"].name.lower()
or q_lower in e["manifest"].description.lower()
]
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
total = len(entries)
start = (page - 1) * _PAGE_SIZE
page_entries = entries[start : start + _PAGE_SIZE]
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[e["manifest"] for e in page_entries],
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
entry = self._catalog.get(plugin_id)
if entry is None:
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": entry["manifest"],
"status": entry["status"],
"install_count": entry["install_count"],
"avg_rating": entry["avg_rating"],
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
@@ -157,54 +115,97 @@ class PluginRegistry:
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id or str(uuid.uuid4())
self._catalog[plugin_id] = {
"manifest": manifest,
"status": "pending_review",
"s3_package_key": package_s3_key,
"install_count": 0,
"avg_rating": 0.0,
"rejection_reason": None,
"submitted_at": int(time.time()),
}
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, plugin_id: str) -> None:
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
if plugin_id not in self._catalog:
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
self._catalog[plugin_id]["status"] = "approved"
self._catalog[plugin_id]["rejection_reason"] = None
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
if plugin_id not in self._catalog:
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
self._catalog[plugin_id]["status"] = "rejected"
self._catalog[plugin_id]["rejection_reason"] = reason
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, plugin_id: str) -> None:
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
if plugin_id in self._catalog:
self._catalog[plugin_id]["install_count"] += 1
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, plugin_id: str) -> None:
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
if plugin_id in self._catalog:
current = self._catalog[plugin_id]["install_count"]
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
def _get_pending_entries(self) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review' (synchronous helper)."""
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton

View File

@@ -1,4 +1,4 @@
"""Plugin review workflow.
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
@@ -11,10 +11,12 @@ Module-level singleton::
from __future__ import annotations
import re
import time
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
@@ -72,20 +74,16 @@ def validate_manifest(manifest: PluginManifest) -> None:
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton so
there is a single source of truth for plugin state.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
def __init__(self) -> None:
# Completed reviews — Step 12 stores in plugin_reviews table
self._reviews: list[dict[str, Any]] = []
async def get_pending(self) -> list[dict[str, Any]]:
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = registry._get_pending_entries()
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
@@ -97,6 +95,7 @@ class ReviewQueue:
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
@@ -108,19 +107,18 @@ class ReviewQueue:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(plugin_id)
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(plugin_id, reason=notes)
await registry.reject_plugin(db, plugin_id, reason=notes)
self._reviews.append(
{
"plugin_id": plugin_id,
"reviewer_id": reviewer_id,
"decision": decision,
"notes": notes,
"reviewed_at": int(time.time()),
}
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton

View File

@@ -1,8 +1,8 @@
"""Revenue share tracking and Stripe Connect payouts.
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Storage is
in-memory until Step 12 migrates to the ``revenue_events`` table.
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
@@ -12,13 +12,16 @@ Module-level singleton::
from __future__ import annotations
import logging
import time
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
@@ -35,10 +38,6 @@ class RevenueShare:
is not configured, consistent with the rest of the billing layer.
"""
def __init__(self) -> None:
# Step 12 replaces with revenue_events DB table
self._events: list[dict[str, Any]] = []
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
@@ -54,6 +53,7 @@ class RevenueShare:
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
@@ -72,11 +72,12 @@ class RevenueShare:
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
plugin_entry = registry._catalog.get(plugin_id)
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_entry:
# Step 12: look up developer's Stripe account from DB
# For now, the author field is used as a placeholder key.
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
@@ -103,22 +104,21 @@ class RevenueShare:
plugin_id,
)
self._events.append(
{
"plugin_id": plugin_id,
"user_id": user_id,
"amount_cents": amount_cents,
"developer_share_cents": developer_share_cents,
"stripe_transfer_id": stripe_transfer_id,
"paid_at": None,
"created_at": int(time.time()),
}
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(plugin_id)
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
@@ -136,54 +136,81 @@ class RevenueShare:
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer
developer_plugin_ids: set[str] = {
pid
for pid, entry in registry._catalog.items()
if entry["manifest"].author == developer_id
}
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM prefix of the created_at timestamp
events = [
e
for e in events
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
]
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": len(events),
"total_revenue_cents": sum(e["amount_cents"] for e in events),
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, plugin_id: str, period: str) -> None:
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
unpaid = [
e
for e in self._events
if e["plugin_id"] == plugin_id
and e["paid_at"] is None
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
]
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_entry = registry._catalog.get(plugin_id)
developer_stripe_account: str | None = None # Step 12: fetch from DB
if plugin_entry and developer_stripe_account:
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
@@ -196,9 +223,10 @@ class RevenueShare:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = int(time.time())
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event["paid_at"] = paid_ts
event.paid_at = paid_ts
await db.commit()
# Module-level singleton

View File

@@ -32,9 +32,9 @@ from sqlalchemy import (
String,
Text,
UniqueConstraint,
Uuid,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db import Base
@@ -64,7 +64,7 @@ class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
@@ -89,10 +89,10 @@ class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
@@ -107,10 +107,10 @@ class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, unique=True, index=True
)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
@@ -128,10 +128,10 @@ class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
@@ -149,10 +149,10 @@ class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
@@ -173,7 +173,7 @@ class Plugin(Base):
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
@@ -207,13 +207,13 @@ class PluginInstallation(Base):
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
@@ -226,13 +226,13 @@ class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
@@ -250,13 +250,13 @@ class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
UUID(as_uuid=False), primary_key=True, default=_uuid
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)

View File

@@ -15,8 +15,10 @@ bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
psycopg2-binary>=2.9.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
aiosqlite>=0.20.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0

208
tests/conftest.py Normal file
View File

@@ -0,0 +1,208 @@
"""Shared test fixtures for database-backed tests.
Provides an async SQLite in-memory engine that auto-creates all tables,
a per-test session, and a FastAPI ``TestClient`` wired to use it.
"""
from __future__ import annotations
import json
import time
import uuid
from collections.abc import AsyncGenerator, Generator
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from jose import jwt
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
from app.models import Plugin, Subscription, User
# ── Fixed test user IDs (one per tier) ───────────────────────────────
TEST_USER_IDS: dict[str, str] = {
"free": "00000000-0000-0000-0000-000000000001",
"pro": "00000000-0000-0000-0000-000000000002",
"power": "00000000-0000-0000-0000-000000000003",
"team": "00000000-0000-0000-0000-000000000004",
}
# ── Async SQLite engine ──────────────────────────────────────────────
_TEST_ENGINE = create_async_engine(
"sqlite+aiosqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
_TestSessionLocal = async_sessionmaker(
_TEST_ENGINE,
expire_on_commit=False,
)
# Enable foreign key enforcement for SQLite (off by default).
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _create_tables():
"""Create all tables before each test, seed test users, then drop after."""
async with _TEST_ENGINE.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Seed one User + Subscription per tier so FK constraints and auth work.
async with _TestSessionLocal() as session:
for tier, uid in TEST_USER_IDS.items():
session.add(User(
id=uid,
email=f"{tier}@test.com",
password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
tier=tier,
))
session.add(Subscription(
id=str(uuid.uuid4()),
user_id=uid,
tier=tier,
stripe_subscription_id=f"sub_test_{tier}",
status="active",
))
await session.commit()
yield
async with _TEST_ENGINE.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""Yield a per-test async DB session."""
async with _TestSessionLocal() as session:
yield session
@pytest.fixture
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001
"""FastAPI test client with ``get_session`` overridden to use the test DB."""
async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
app.dependency_overrides[get_session] = _override_get_session
with TestClient(app) as c:
yield c
app.dependency_overrides.pop(get_session, None)
# ── Seed data helpers ────────────────────────────────────────────────
_SEED_PLUGINS = [
Plugin(
id="plugin-github-sync",
name="GitHub Sync",
description="Sync tasks with GitHub Issues and pull requests.",
version="1.0.0",
author_name="Adiuva",
category="productivity",
price_cents=0,
permissions=json.dumps(["read:tasks", "write:tasks"]),
status="approved",
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
install_count=0,
avg_rating=0.0,
),
Plugin(
id="plugin-slack-notify",
name="Slack Notifier",
description="Post task and checkpoint updates to Slack channels.",
version="1.2.0",
author_name="Adiuva",
category="communication",
price_cents=499,
permissions=json.dumps(["read:tasks", "read:checkpoints"]),
status="approved",
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
install_count=0,
avg_rating=0.0,
),
Plugin(
id="plugin-time-tracker",
name="Time Tracker",
description="Track time spent on tasks with automatic reporting.",
version="0.9.1",
author_name="Third Party",
category="productivity",
price_cents=999,
permissions=json.dumps(["read:tasks", "write:tasks"]),
status="approved",
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
install_count=0,
avg_rating=0.0,
),
]
@pytest_asyncio.fixture
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
"""Insert the 3 default approved plugins and return them."""
plugins = []
for template in _SEED_PLUGINS:
p = Plugin(
id=template.id,
name=template.name,
description=template.description,
version=template.version,
author_name=template.author_name,
category=template.category,
price_cents=template.price_cents,
permissions=template.permissions,
status=template.status,
s3_package_key=template.s3_package_key,
install_count=template.install_count,
avg_rating=template.avg_rating,
)
db_session.add(p)
plugins.append(p)
await db_session.commit()
return plugins
# ── JWT helpers ──────────────────────────────────────────────────────
def make_jwt(
tier: str = "power",
user_id: str | None = None,
email: str | None = None,
) -> str:
"""Create a signed test JWT.
Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can
find the corresponding ``Subscription`` row in the test database.
"""
uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4()))
now = int(time.time())
payload = {
"sub": uid,
"email": email or f"{tier}@test.com",
"tier": tier,
"exp": now + 3600,
"iat": now,
}
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]:
"""Return an Authorization header dict for the given tier."""
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}

View File

@@ -18,13 +18,30 @@ from fastapi.testclient import TestClient
from jose import jwt
from app.config.settings import settings
from app.db import get_session
from app.main import app
from app.schemas import ChatResponse
from tests.conftest import TEST_USER_IDS
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Autouse: redirect all DB access to the in-memory SQLite test engine.
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
"""Route all get_session calls to the test SQLite session."""
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
_CHAT_BODY = {
"message": "hello",
"context": {
@@ -74,14 +91,15 @@ class TestAuthMiddleware:
"""Tests exercised via GET /api/v1/auth/me."""
def test_valid_token_returns_profile(self) -> None:
uid = str(uuid.uuid4())
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
# Use the seeded pro user so the subscription lookup returns 'pro'.
uid = TEST_USER_IDS["pro"]
token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro")
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
data = resp.json()
assert data["id"] == uid
assert data["email"] == "alice@example.com"
assert data["email"] == "pro@test.com"
assert data["tier"] == "pro"
def test_missing_token_returns_401(self) -> None:

View File

@@ -1,52 +1,34 @@
"""Tests for Step 10: Plugin Marketplace.
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
Covers:
- PluginRegistry: catalog management, filtering, sorting, install counts
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
- ReviewQueue: pending queue, review decisions, manifest security checklist
- RevenueShare: install event recording, earnings aggregation
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
- Route integration: tier gate, list/get/install/uninstall via TestClient
"""
from __future__ import annotations
import time
import json
import uuid
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from jose import jwt
from unittest.mock import patch
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.main import app
from app.marketplace.plugin_registry import PluginRegistry
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
from app.marketplace.revenue_share import RevenueShare
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
from app.schemas import PluginManifest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_jwt(tier: str = "power", user_id: str | None = None) -> str:
uid = user_id or str(uuid.uuid4())
now = int(time.time())
payload = {
"sub": uid,
"email": f"{uid[:8]}@example.com",
"tier": tier,
"exp": now + 3600,
"iat": now,
}
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
def _auth(tier: str = "power") -> dict[str, str]:
return {"Authorization": f"Bearer {_make_jwt(tier)}"}
def _fresh_manifest(
plugin_id: str | None = None,
category: str = "productivity",
@@ -67,118 +49,150 @@ def _fresh_manifest(
# ---------------------------------------------------------------------------
# PluginRegistry
# PluginRegistry (DB-backed)
# ---------------------------------------------------------------------------
class TestPluginRegistry:
"""Each test uses a fresh PluginRegistry instance to avoid catalog pollution."""
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.mark.asyncio
async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None:
result = await reg.list_plugins()
async def test_seed_plugins_are_listed(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session)
assert result.total == 3
assert all(p.id.startswith("plugin-") for p in result.plugins)
@pytest.mark.asyncio
async def test_list_approved_only(self, reg: PluginRegistry) -> None:
async def test_list_approved_only(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "plugins/key.zip")
result = await reg.list_plugins()
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
result = await reg.list_plugins(db_session)
ids = [p.id for p in result.plugins]
assert manifest.id not in ids # still pending
@pytest.mark.asyncio
async def test_list_filter_by_category(self, reg: PluginRegistry) -> None:
result = await reg.list_plugins(category="communication")
async def test_list_filter_by_category(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, category="communication")
assert result.total == 1
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_list_filter_by_query(self, reg: PluginRegistry) -> None:
result = await reg.list_plugins(query="time")
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"
@pytest.mark.asyncio
async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None:
await reg.record_install("plugin-slack-notify")
await reg.record_install("plugin-slack-notify")
result = await reg.list_plugins(sort="installs")
async def test_list_sort_by_installs(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-slack-notify")
await reg.record_install(db_session, "plugin-slack-notify")
result = await reg.list_plugins(db_session, sort="installs")
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_get_plugin_found(self, reg: PluginRegistry) -> None:
entry = await reg.get_plugin("plugin-github-sync")
async def test_get_plugin_found(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["manifest"].id == "plugin-github-sync"
assert "install_count" in entry
@pytest.mark.asyncio
async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None:
entry = await reg.get_plugin("no-such-plugin")
async def test_get_plugin_not_found(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
entry = await reg.get_plugin(db_session, "no-such-plugin")
assert entry is None
@pytest.mark.asyncio
async def test_submit_sets_pending(self, reg: PluginRegistry) -> None:
async def test_submit_sets_pending(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
plugin_id = await reg.submit_plugin(manifest, "key.zip")
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
assert plugin_id == manifest.id
assert reg._catalog[plugin_id]["status"] == "pending_review"
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one()
assert row.status == "pending_review"
@pytest.mark.asyncio
async def test_approve_makes_visible(self, reg: PluginRegistry) -> None:
async def test_approve_makes_visible(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await reg.approve_plugin(manifest.id)
result = await reg.list_plugins()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.approve_plugin(db_session, manifest.id)
result = await reg.list_plugins(db_session)
assert manifest.id in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_reject_stores_reason(self, reg: PluginRegistry) -> None:
async def test_reject_stores_reason(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await reg.reject_plugin(manifest.id, reason="Unsafe permissions")
assert reg._catalog[manifest.id]["status"] == "rejected"
assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions"
result = await reg.list_plugins()
assert manifest.id not in [p.id for p in result.plugins]
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
assert row.rejection_reason == "Unsafe permissions"
listed = await reg.list_plugins(db_session)
assert manifest.id not in [p.id for p in listed.plugins]
@pytest.mark.asyncio
async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None:
async def test_approve_unknown_raises_key_error(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
with pytest.raises(KeyError):
await reg.approve_plugin("ghost-plugin")
await reg.approve_plugin(db_session, "ghost-plugin")
@pytest.mark.asyncio
async def test_record_install_increments_count(self, reg: PluginRegistry) -> None:
await reg.record_install("plugin-github-sync")
entry = await reg.get_plugin("plugin-github-sync")
async def test_record_install_increments_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None:
await reg.record_install("plugin-github-sync")
await reg.record_install("plugin-github-sync")
await reg.record_uninstall("plugin-github-sync")
entry = await reg.get_plugin("plugin-github-sync")
async def test_record_uninstall_decrements_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None:
await reg.record_uninstall("plugin-github-sync") # already 0
entry = await reg.get_plugin("plugin-github-sync")
async def test_record_uninstall_floors_at_zero(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 0
# ---------------------------------------------------------------------------
# ReviewQueue
# ReviewQueue (DB-backed)
# ---------------------------------------------------------------------------
@@ -188,37 +202,47 @@ class TestReviewQueue:
return PluginRegistry()
@pytest.fixture
def queue(self, reg: PluginRegistry) -> ReviewQueue:
# Patch the 'registry' name as bound inside plugin_review.py
with patch("app.marketplace.plugin_review.registry", reg):
yield ReviewQueue()
def queue(self) -> ReviewQueue:
return ReviewQueue()
@pytest.mark.asyncio
async def test_get_pending_returns_submitted_plugins(
self, reg: PluginRegistry, queue: ReviewQueue
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
pending = await queue.get_pending()
await reg.submit_plugin(db_session, manifest, "key.zip")
pending = await queue.get_pending(db_session)
assert any(p["plugin_id"] == manifest.id for p in pending)
@pytest.mark.asyncio
async def test_submit_review_approved(
self, reg: PluginRegistry, queue: ReviewQueue
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good")
assert reg._catalog[manifest.id]["status"] == "approved"
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "approved"
# Check review row was persisted
review_result = await db_session.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
)
review = review_result.scalar_one()
assert review.decision == "approved"
@pytest.mark.asyncio
async def test_submit_review_rejected(
self, reg: PluginRegistry, queue: ReviewQueue
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(manifest, "key.zip")
await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions")
assert reg._catalog[manifest.id]["status"] == "rejected"
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
)
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
def test_validate_manifest_ok(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
@@ -241,65 +265,66 @@ class TestReviewQueue:
# ---------------------------------------------------------------------------
# RevenueShare
# RevenueShare (DB-backed)
# ---------------------------------------------------------------------------
class TestRevenueShare:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def rs(self, reg: PluginRegistry) -> RevenueShare:
# Patch the 'registry' name as bound inside revenue_share.py
with patch("app.marketplace.revenue_share.registry", reg):
yield RevenueShare()
def rs(self) -> RevenueShare:
return RevenueShare()
@pytest.mark.asyncio
async def test_record_install_free_plugin(
self, reg: PluginRegistry, rs: RevenueShare
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
assert len(rs._events) == 1
assert rs._events[0]["developer_share_cents"] == 0
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
)
event = result.scalar_one()
assert event.developer_share_cents == 0
@pytest.mark.asyncio
async def test_record_install_paid_plugin_no_stripe(
self, reg: PluginRegistry, rs: RevenueShare
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
# No STRIPE_SECRET_KEY configured in test env — should not crash
await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499)
assert len(rs._events) == 1
assert rs._events[0]["amount_cents"] == 499
assert rs._events[0]["developer_share_cents"] == int(499 * 0.70)
await rs.record_install(
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
)
event = result.scalar_one()
assert event.amount_cents == 499
assert event.developer_share_cents == int(499 * 0.70)
@pytest.mark.asyncio
async def test_record_install_increments_registry_count(
self, reg: PluginRegistry, rs: RevenueShare
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
entry = await reg.get_plugin("plugin-github-sync")
reg = PluginRegistry()
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_get_earnings_empty(
self, reg: PluginRegistry, rs: RevenueShare
self, rs: RevenueShare, db_session: AsyncSession
) -> None:
result = await rs.get_earnings("unknown-dev")
result = await rs.get_earnings(db_session, "unknown-dev")
assert result["total_installs"] == 0
assert result["total_revenue_cents"] == 0
assert result["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_get_earnings_aggregates(
self, reg: PluginRegistry, rs: RevenueShare
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
# "Adiuva" is the author of the seeded plugins
await rs.record_install("plugin-slack-notify", "u1", amount_cents=499)
await rs.record_install("plugin-slack-notify", "u2", amount_cents=499)
result = await rs.get_earnings("Adiuva")
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
result = await rs.get_earnings(db_session, "Adiuva")
assert result["total_installs"] == 2
assert result["total_revenue_cents"] == 998
assert result["developer_share_cents"] == int(499 * 0.70) * 2
@@ -311,77 +336,67 @@ class TestRevenueShare:
class TestPluginRoutes:
def test_list_plugins_requires_power_tier(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("free"))
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
assert resp.status_code == 403
def test_list_plugins_pro_tier_blocked(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("pro"))
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
assert resp.status_code == 403
def test_list_plugins_power_tier_ok(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("power"))
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
assert resp.status_code == 200
data = resp.json()
assert "plugins" in data
assert data["total"] >= 3
assert data["total"] == 3
def test_list_plugins_team_tier_ok(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins", headers=_auth("team"))
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
assert resp.status_code == 200
def test_get_plugin_found(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth())
def test_get_plugin_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
assert resp.status_code == 200
data = resp.json()
assert data["plugin"]["id"] == "plugin-github-sync"
assert "install_count" in data
def test_get_plugin_not_found(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth())
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
assert resp.status_code == 404
def test_install_plugin_free(self) -> None:
with TestClient(app) as client:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=_auth(),
)
def test_install_plugin_free(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header(),
)
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert "download_url" in data
def test_install_plugin_not_found(self) -> None:
with TestClient(app) as client:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=_auth(),
)
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=auth_header(),
)
assert resp.status_code == 404
def test_uninstall_plugin_ok(self) -> None:
with TestClient(app) as client:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=_auth(),
)
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=auth_header(),
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_install_requires_power_tier(self) -> None:
with TestClient(app) as client:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=_auth("free"),
)
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header("free"),
)
assert resp.status_code == 403