Step 12 - completed
This commit is contained in:
@@ -439,7 +439,7 @@ adiuva-api/
|
||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||
|
||||
### Step 12 — Database (auth/billing/marketplace only)
|
||||
- [ ] PostgreSQL schema via Alembic:
|
||||
- [x] PostgreSQL schema via Alembic:
|
||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
||||
@@ -449,8 +449,8 @@ adiuva-api/
|
||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
||||
- [ ] Initial Alembic migration
|
||||
- [ ] SQLAlchemy models in `app/models.py`
|
||||
- [x] Initial Alembic migration
|
||||
- [x] SQLAlchemy models in `app/models.py`
|
||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
||||
|
||||
### Step 13 — Testing & deployment
|
||||
|
||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-03-03
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_SEED_PLUGINS = [
|
||||
{
|
||||
"id": "plugin-github-sync",
|
||||
"name": "GitHub Sync",
|
||||
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||
"version": "1.0.0",
|
||||
"author_name": "Adiuva",
|
||||
"category": "productivity",
|
||||
"price_cents": 0,
|
||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
},
|
||||
{
|
||||
"id": "plugin-slack-notify",
|
||||
"name": "Slack Notifier",
|
||||
"description": "Post task and checkpoint updates to Slack channels.",
|
||||
"version": "1.2.0",
|
||||
"author_name": "Adiuva",
|
||||
"category": "communication",
|
||||
"price_cents": 499,
|
||||
"permissions": json.dumps(["read:tasks", "read:checkpoints"]),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
},
|
||||
{
|
||||
"id": "plugin-time-tracker",
|
||||
"name": "Time Tracker",
|
||||
"description": "Track time spent on tasks with automatic reporting.",
|
||||
"version": "0.9.1",
|
||||
"author_name": "Third Party",
|
||||
"category": "productivity",
|
||||
"price_cents": 999,
|
||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
plugins = sa.table(
|
||||
"plugins",
|
||||
sa.column("id", sa.String),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("description", sa.Text),
|
||||
sa.column("version", sa.String),
|
||||
sa.column("author_name", sa.String),
|
||||
sa.column("category", sa.String),
|
||||
sa.column("price_cents", sa.Integer),
|
||||
sa.column("permissions", sa.Text),
|
||||
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||
sa.column("s3_package_key", sa.String),
|
||||
sa.column("install_count", sa.Integer),
|
||||
sa.column("avg_rating", sa.Float),
|
||||
)
|
||||
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"DELETE FROM plugins WHERE id IN ("
|
||||
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||
")"
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||
|
||||
Blobs are stored in S3 via BlobStore. Backup metadata is kept in an
|
||||
in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table).
|
||||
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||
PostgreSQL ``backup_metadata`` table.
|
||||
|
||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||
treating "history" as a ``{backup_id}`` path parameter.
|
||||
@@ -9,14 +9,17 @@ treating "history" as a ``{backup_id}`` path parameter.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.db import get_session
|
||||
from app.models import BackupMetadata as BackupMetadataModel
|
||||
from app.schemas import BackupMetadata, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
@@ -25,14 +28,25 @@ router = APIRouter(prefix="/backup", tags=["backup"])
|
||||
|
||||
_blob_store = BlobStore()
|
||||
|
||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
||||
|
||||
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||
"""Return total backup bytes stored by *user_id*."""
|
||||
result = await db.execute(
|
||||
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||
BackupMetadataModel.user_id == user_id
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
def _check_backup_quota(user_id: str, size_bytes: int) -> None:
|
||||
async def _check_backup_quota(
|
||||
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||
) -> None:
|
||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||
current = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||
tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes)
|
||||
current = await _current_backup_bytes(user.id, db)
|
||||
tier_manager.enforce_backup_quota(
|
||||
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
@@ -42,6 +56,7 @@ async def upload_backup(
|
||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Upload an E2E-encrypted backup blob.
|
||||
|
||||
@@ -49,24 +64,23 @@ async def upload_backup(
|
||||
"""
|
||||
blob = await request.body()
|
||||
reject_if_tampered(blob, x_backup_checksum)
|
||||
_check_backup_quota(current_user.id, len(blob))
|
||||
await _check_backup_quota(current_user, len(blob), db)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||
)
|
||||
|
||||
backup_record: dict[str, Any] = {
|
||||
"id": str(x_backup_timestamp),
|
||||
"s3_key": s3_key,
|
||||
"version": x_backup_version,
|
||||
"timestamp": x_backup_timestamp,
|
||||
"checksum": x_backup_checksum,
|
||||
"size_bytes": len(blob),
|
||||
}
|
||||
|
||||
user_backups = _backups.setdefault(current_user.id, [])
|
||||
user_backups.append(backup_record)
|
||||
user_backups.sort(key=lambda b: b["timestamp"], reverse=True)
|
||||
row = BackupMetadataModel(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
s3_key=s3_key,
|
||||
version=x_backup_version,
|
||||
timestamp=x_backup_timestamp,
|
||||
checksum=x_backup_checksum,
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
db.add(row)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@@ -74,16 +88,23 @@ async def upload_backup(
|
||||
@router.get("/history", response_model=list[BackupMetadata])
|
||||
async def backup_history(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[BackupMetadata]:
|
||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||
result = await db.execute(
|
||||
select(BackupMetadataModel)
|
||||
.where(BackupMetadataModel.user_id == current_user.id)
|
||||
.order_by(BackupMetadataModel.timestamp.desc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
BackupMetadata(
|
||||
version=b["version"],
|
||||
timestamp=b["timestamp"],
|
||||
checksum=b["checksum"],
|
||||
chunk_count=1, # single-chunk uploads for now — TODO(Step12): track real count
|
||||
version=r.version,
|
||||
timestamp=r.timestamp,
|
||||
checksum=r.checksum,
|
||||
chunk_count=1,
|
||||
)
|
||||
for b in _backups.get(current_user.id, [])
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@@ -91,32 +112,37 @@ async def backup_history(
|
||||
async def download_backup(
|
||||
request: Request,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||
user_backups = _backups.get(current_user.id, [])
|
||||
if not user_backups:
|
||||
result = await db.execute(
|
||||
select(BackupMetadataModel)
|
||||
.where(BackupMetadataModel.user_id == current_user.id)
|
||||
.order_by(BackupMetadataModel.timestamp.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest = result.scalar_one_or_none()
|
||||
if latest is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||
|
||||
latest = user_backups[0]
|
||||
|
||||
ims_header = request.headers.get("If-Modified-Since")
|
||||
if ims_header:
|
||||
try:
|
||||
ims_dt = parsedate_to_datetime(ims_header)
|
||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||
if latest["timestamp"] <= ims_ms:
|
||||
if latest.timestamp <= ims_ms:
|
||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||
except Exception:
|
||||
pass # malformed header — ignore and serve the blob
|
||||
|
||||
blob = await _blob_store.download(current_user.id, latest["s3_key"])
|
||||
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||
return Response(
|
||||
content=blob,
|
||||
media_type="application/octet-stream",
|
||||
headers={
|
||||
"X-Backup-Version": str(latest["version"]),
|
||||
"X-Backup-Timestamp": str(latest["timestamp"]),
|
||||
"X-Checksum": latest["checksum"],
|
||||
"X-Backup-Version": str(latest.version),
|
||||
"X-Backup-Timestamp": str(latest.timestamp),
|
||||
"X-Checksum": latest.checksum,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -125,14 +151,21 @@ async def download_backup(
|
||||
async def delete_backup(
|
||||
backup_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a specific backup by ID."""
|
||||
user_backups = _backups.get(current_user.id, [])
|
||||
target = next((b for b in user_backups if b["id"] == backup_id), None)
|
||||
result = await db.execute(
|
||||
select(BackupMetadataModel).where(
|
||||
BackupMetadataModel.id == backup_id,
|
||||
BackupMetadataModel.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
target = result.scalar_one_or_none()
|
||||
if target is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||
|
||||
await _blob_store.delete(current_user.id, target["s3_key"])
|
||||
_backups[current_user.id] = [b for b in user_backups if b["id"] != backup_id]
|
||||
await _blob_store.delete(current_user.id, target.s3_key)
|
||||
await db.delete(target)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Plugins routes: browse and install plugins from the marketplace.
|
||||
|
||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
|
||||
in Step 10. Step 12 will swap those services' in-memory stores for
|
||||
PostgreSQL persistence.
|
||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -11,10 +10,14 @@ from typing import Any, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.db import get_session
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.marketplace.revenue_share import revenue_share
|
||||
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||
@@ -36,7 +39,7 @@ def _require_plugin_tier(user: UserProfile) -> None:
|
||||
class _PluginDetail(BaseModel):
|
||||
plugin: PluginManifest
|
||||
install_count: int
|
||||
ratings: list[Any] # Step 12 populates from plugin_reviews table
|
||||
ratings: list[Any]
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
@@ -48,26 +51,44 @@ async def list_plugins(
|
||||
page: int = Query(default=1, ge=1),
|
||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> PluginListResponse:
|
||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
|
||||
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||
async def get_plugin(
|
||||
plugin_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> _PluginDetail:
|
||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
entry = await registry.get_plugin(plugin_id)
|
||||
entry = await registry.get_plugin(db, plugin_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
|
||||
# Fetch review ratings for this plugin
|
||||
review_result = await db.execute(
|
||||
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||
)
|
||||
reviews = review_result.scalars().all()
|
||||
ratings = [
|
||||
{
|
||||
"reviewer_id": r.reviewer_id,
|
||||
"decision": r.decision,
|
||||
"notes": r.notes,
|
||||
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||
}
|
||||
for r in reviews
|
||||
]
|
||||
|
||||
return _PluginDetail(
|
||||
plugin=entry["manifest"],
|
||||
install_count=entry["install_count"],
|
||||
ratings=[], # Step 12 populates from plugin_reviews table
|
||||
ratings=ratings,
|
||||
)
|
||||
|
||||
|
||||
@@ -76,17 +97,27 @@ async def install_plugin(
|
||||
plugin_id: str,
|
||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||
|
||||
Requires Power tier or above.
|
||||
"""
|
||||
_require_plugin_tier(current_user)
|
||||
entry = await registry.get_plugin(plugin_id)
|
||||
entry = await registry.get_plugin(db, plugin_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
|
||||
# Record the installation in plugin_installations
|
||||
installation = PluginInstallation(
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
db.add(installation)
|
||||
await db.flush()
|
||||
|
||||
await revenue_share.record_install(
|
||||
db,
|
||||
plugin_id=plugin_id,
|
||||
user_id=current_user.id,
|
||||
amount_cents=entry["manifest"].price_cents,
|
||||
@@ -100,7 +131,18 @@ async def install_plugin(
|
||||
async def uninstall_plugin(
|
||||
plugin_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Unregister a plugin installation."""
|
||||
await registry.record_uninstall(plugin_id)
|
||||
result = await db.execute(
|
||||
select(PluginInstallation).where(
|
||||
PluginInstallation.plugin_id == plugin_id,
|
||||
PluginInstallation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
installation = result.scalar_one_or_none()
|
||||
if installation is not None:
|
||||
await db.delete(installation)
|
||||
await db.commit()
|
||||
await registry.record_uninstall(db, plugin_id)
|
||||
return {"ok": True}
|
||||
|
||||
@@ -1,20 +1,23 @@
|
||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||
|
||||
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
|
||||
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
|
||||
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||
PostgreSQL ``storage_records`` table.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import tier_manager
|
||||
from app.db import get_session
|
||||
from app.models import StorageRecord
|
||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
@@ -23,9 +26,6 @@ router = APIRouter(prefix="/storage", tags=["storage"])
|
||||
|
||||
_blob_store = BlobStore()
|
||||
|
||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
||||
_records: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
# ── Local response schemas ─────────────────────────────────────────────
|
||||
|
||||
@@ -44,17 +44,34 @@ class _RecordMeta(BaseModel):
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _check_quota(user_id: str, additional_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
||||
current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||
tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes)
|
||||
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||
"""Return total bytes stored by *user_id*."""
|
||||
result = await db.execute(
|
||||
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||
StorageRecord.user_id == user_id
|
||||
)
|
||||
)
|
||||
return int(result.scalar_one())
|
||||
|
||||
|
||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
||||
"""Look up a record and verify ownership. Always returns 404 on mismatch
|
||||
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||
current = await _current_usage_bytes(user.id, db)
|
||||
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||
|
||||
|
||||
async def _get_record_for_user(
|
||||
record_id: str, user_id: str, db: AsyncSession
|
||||
) -> StorageRecord:
|
||||
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||
to prevent user enumeration attacks."""
|
||||
record = _records.get(record_id)
|
||||
if record is None or record["user_id"] != user_id:
|
||||
result = await db.execute(
|
||||
select(StorageRecord).where(
|
||||
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if record is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||
return record
|
||||
|
||||
@@ -65,30 +82,32 @@ def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
||||
async def create_record(
|
||||
body: StorageRecordCreate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> _CreateResponse:
|
||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
_check_quota(current_user.id, len(body.blob))
|
||||
await _check_quota(current_user, len(body.blob), db)
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = int(time.time() * 1000)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||
)
|
||||
|
||||
_records[record_id] = {
|
||||
"id": record_id,
|
||||
"user_id": current_user.id,
|
||||
"table": body.table,
|
||||
"s3_key": s3_key,
|
||||
"checksum": body.checksum,
|
||||
"size_bytes": len(body.blob),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
record = StorageRecord(
|
||||
id=record_id,
|
||||
user_id=current_user.id,
|
||||
table_name=body.table,
|
||||
s3_key=s3_key,
|
||||
checksum=body.checksum,
|
||||
size_bytes=len(body.blob),
|
||||
)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
|
||||
return _CreateResponse(id=record_id, created_at=now)
|
||||
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||
|
||||
|
||||
@router.get("/records", response_model=list[_RecordMeta])
|
||||
@@ -97,23 +116,26 @@ async def list_records(
|
||||
page: int = Query(default=1, ge=1),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[_RecordMeta]:
|
||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||
all_records = [
|
||||
r for r in _records.values()
|
||||
if r["user_id"] == current_user.id and (table is None or r["table"] == table)
|
||||
]
|
||||
start = (page - 1) * limit
|
||||
page_records = all_records[start : start + limit]
|
||||
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||
if table is not None:
|
||||
query = query.where(StorageRecord.table_name == table)
|
||||
query = query.offset((page - 1) * limit).limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.scalars().all()
|
||||
|
||||
return [
|
||||
_RecordMeta(
|
||||
id=r["id"],
|
||||
table=r["table"],
|
||||
checksum=r["checksum"],
|
||||
created_at=r["created_at"],
|
||||
updated_at=r["updated_at"],
|
||||
id=r.id,
|
||||
table=r.table_name,
|
||||
checksum=r.checksum,
|
||||
created_at=int(r.created_at.timestamp() * 1000),
|
||||
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||
)
|
||||
for r in page_records
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@@ -121,14 +143,15 @@ async def list_records(
|
||||
async def download_record(
|
||||
record_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||
record = _get_record_for_user(record_id, current_user.id)
|
||||
blob = await _blob_store.download(current_user.id, record["s3_key"])
|
||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||
return Response(
|
||||
content=blob,
|
||||
media_type="application/octet-stream",
|
||||
headers={"X-Checksum": record["checksum"]},
|
||||
headers={"X-Checksum": record.checksum},
|
||||
)
|
||||
|
||||
|
||||
@@ -137,23 +160,24 @@ async def update_record(
|
||||
record_id: str,
|
||||
body: StorageRecordUpdate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||
record = _get_record_for_user(record_id, current_user.id)
|
||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
|
||||
delta = len(body.blob) - record["size_bytes"]
|
||||
delta = len(body.blob) - record.size_bytes
|
||||
if delta > 0:
|
||||
_check_quota(current_user.id, delta)
|
||||
await _check_quota(current_user, delta, db)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
||||
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||
)
|
||||
|
||||
record["s3_key"] = s3_key
|
||||
record["checksum"] = body.checksum
|
||||
record["size_bytes"] = len(body.blob)
|
||||
record["updated_at"] = int(time.time() * 1000)
|
||||
record.s3_key = s3_key
|
||||
record.checksum = body.checksum
|
||||
record.size_bytes = len(body.blob)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@@ -162,9 +186,11 @@ async def update_record(
|
||||
async def delete_record(
|
||||
record_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a record and its S3 blob."""
|
||||
record = _get_record_for_user(record_id, current_user.id)
|
||||
await _blob_store.delete(current_user.id, record["s3_key"])
|
||||
del _records[record_id]
|
||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||
await _blob_store.delete(current_user.id, record.s3_key)
|
||||
await db.delete(record)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
"""Plugin catalog registry.
|
||||
"""Plugin catalog registry backed by PostgreSQL.
|
||||
|
||||
Maintains the authoritative list of plugins, their review status, and
|
||||
aggregate install counts. Storage is in-memory until Step 12 migrates to
|
||||
the ``plugins`` PostgreSQL table.
|
||||
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
@@ -11,144 +10,103 @@ Module-level singleton::
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import Plugin
|
||||
from app.schemas import PluginListResponse, PluginManifest
|
||||
|
||||
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
|
||||
|
||||
_SEED_PLUGINS: list[dict[str, Any]] = [
|
||||
{
|
||||
"manifest": PluginManifest(
|
||||
id="plugin-github-sync",
|
||||
name="GitHub Sync",
|
||||
description="Sync tasks with GitHub Issues and pull requests.",
|
||||
version="1.0.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=0,
|
||||
),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
},
|
||||
{
|
||||
"manifest": PluginManifest(
|
||||
id="plugin-slack-notify",
|
||||
name="Slack Notifier",
|
||||
description="Post task and checkpoint updates to Slack channels.",
|
||||
version="1.2.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "read:checkpoints"],
|
||||
category="communication",
|
||||
price_cents=499,
|
||||
),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
},
|
||||
{
|
||||
"manifest": PluginManifest(
|
||||
id="plugin-time-tracker",
|
||||
name="Time Tracker",
|
||||
description="Track time spent on tasks with automatic reporting.",
|
||||
version="0.9.1",
|
||||
author="Third Party",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=999,
|
||||
),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
},
|
||||
]
|
||||
|
||||
_PAGE_SIZE = 20
|
||||
|
||||
|
||||
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||
try:
|
||||
permissions = json.loads(p.permissions) if p.permissions else []
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
permissions = []
|
||||
return PluginManifest(
|
||||
id=p.id,
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
version=p.version,
|
||||
author=p.author_name,
|
||||
permissions=permissions,
|
||||
category=p.category,
|
||||
price_cents=p.price_cents,
|
||||
)
|
||||
|
||||
|
||||
class PluginRegistry:
|
||||
"""In-process plugin catalog.
|
||||
"""PostgreSQL-backed plugin catalog.
|
||||
|
||||
All mutating methods are ``async`` to make the future DB swap transparent
|
||||
to callers.
|
||||
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||
controls the session lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# plugin_id → entry dict (deep-copied so each instance is independent)
|
||||
self._catalog: dict[str, dict[str, Any]] = {
|
||||
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
|
||||
}
|
||||
|
||||
# ── Queries ──────────────────────────────────────────────────────
|
||||
|
||||
async def list_plugins(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
category: str | None = None,
|
||||
query: str | None = None,
|
||||
page: int = 1,
|
||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||
) -> PluginListResponse:
|
||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
|
||||
base = select(Plugin).where(Plugin.status == "approved")
|
||||
|
||||
if category:
|
||||
entries = [e for e in entries if e["manifest"].category == category]
|
||||
|
||||
base = base.where(Plugin.category == category)
|
||||
if query:
|
||||
q_lower = query.lower()
|
||||
entries = [
|
||||
e
|
||||
for e in entries
|
||||
if q_lower in e["manifest"].name.lower()
|
||||
or q_lower in e["manifest"].description.lower()
|
||||
]
|
||||
pattern = f"%{query}%"
|
||||
base = base.where(
|
||||
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||
)
|
||||
|
||||
# Count
|
||||
count_q = select(func.count()).select_from(base.subquery())
|
||||
total = (await db.execute(count_q)).scalar_one()
|
||||
|
||||
# Sort
|
||||
if sort == "installs":
|
||||
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
|
||||
base = base.order_by(Plugin.install_count.desc())
|
||||
elif sort == "rating":
|
||||
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
|
||||
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
|
||||
base = base.order_by(Plugin.avg_rating.desc())
|
||||
else: # newest
|
||||
base = base.order_by(Plugin.created_at.desc())
|
||||
|
||||
total = len(entries)
|
||||
start = (page - 1) * _PAGE_SIZE
|
||||
page_entries = entries[start : start + _PAGE_SIZE]
|
||||
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||
rows = (await db.execute(base)).scalars().all()
|
||||
|
||||
return PluginListResponse(
|
||||
plugins=[e["manifest"] for e in page_entries],
|
||||
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||
total=total,
|
||||
page=page,
|
||||
)
|
||||
|
||||
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
|
||||
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||
entry = self._catalog.get(plugin_id)
|
||||
if entry is None:
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
p = result.scalar_one_or_none()
|
||||
if p is None:
|
||||
return None
|
||||
return {
|
||||
"manifest": entry["manifest"],
|
||||
"status": entry["status"],
|
||||
"install_count": entry["install_count"],
|
||||
"avg_rating": entry["avg_rating"],
|
||||
"manifest": _plugin_to_manifest(p),
|
||||
"status": p.status,
|
||||
"install_count": p.install_count,
|
||||
"avg_rating": p.avg_rating,
|
||||
}
|
||||
|
||||
# ── Mutations ────────────────────────────────────────────────────
|
||||
|
||||
async def submit_plugin(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
manifest: PluginManifest,
|
||||
package_s3_key: str,
|
||||
) -> str:
|
||||
@@ -157,54 +115,97 @@ class PluginRegistry:
|
||||
Returns the plugin_id. If a plugin with the same id already exists
|
||||
it is overwritten (re-submission after rejection).
|
||||
"""
|
||||
plugin_id = manifest.id or str(uuid.uuid4())
|
||||
self._catalog[plugin_id] = {
|
||||
"manifest": manifest,
|
||||
"status": "pending_review",
|
||||
"s3_package_key": package_s3_key,
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
"rejection_reason": None,
|
||||
"submitted_at": int(time.time()),
|
||||
}
|
||||
plugin_id = manifest.id
|
||||
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = existing.scalar_one_or_none()
|
||||
|
||||
if row is not None:
|
||||
row.name = manifest.name
|
||||
row.description = manifest.description
|
||||
row.version = manifest.version
|
||||
row.author_name = manifest.author
|
||||
row.category = manifest.category
|
||||
row.price_cents = manifest.price_cents
|
||||
row.permissions = json.dumps(manifest.permissions)
|
||||
row.status = "pending_review"
|
||||
row.s3_package_key = package_s3_key
|
||||
row.rejection_reason = None
|
||||
else:
|
||||
row = Plugin(
|
||||
id=plugin_id,
|
||||
name=manifest.name,
|
||||
description=manifest.description,
|
||||
version=manifest.version,
|
||||
author_name=manifest.author,
|
||||
category=manifest.category,
|
||||
price_cents=manifest.price_cents,
|
||||
permissions=json.dumps(manifest.permissions),
|
||||
status="pending_review",
|
||||
s3_package_key=package_s3_key,
|
||||
install_count=0,
|
||||
avg_rating=0.0,
|
||||
)
|
||||
db.add(row)
|
||||
await db.commit()
|
||||
return plugin_id
|
||||
|
||||
async def approve_plugin(self, plugin_id: str) -> None:
|
||||
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||
"""Set *plugin_id* status to ``'approved'``.
|
||||
|
||||
Raises ``KeyError`` if the plugin is not found.
|
||||
"""
|
||||
if plugin_id not in self._catalog:
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||
self._catalog[plugin_id]["status"] = "approved"
|
||||
self._catalog[plugin_id]["rejection_reason"] = None
|
||||
row.status = "approved"
|
||||
row.rejection_reason = None
|
||||
await db.commit()
|
||||
|
||||
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
|
||||
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||
|
||||
Raises ``KeyError`` if the plugin is not found.
|
||||
"""
|
||||
if plugin_id not in self._catalog:
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||
self._catalog[plugin_id]["status"] = "rejected"
|
||||
self._catalog[plugin_id]["rejection_reason"] = reason
|
||||
row.status = "rejected"
|
||||
row.rejection_reason = reason
|
||||
await db.commit()
|
||||
|
||||
async def record_install(self, plugin_id: str) -> None:
|
||||
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||
if plugin_id in self._catalog:
|
||||
self._catalog[plugin_id]["install_count"] += 1
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.install_count = row.install_count + 1
|
||||
await db.commit()
|
||||
|
||||
async def record_uninstall(self, plugin_id: str) -> None:
|
||||
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||
if plugin_id in self._catalog:
|
||||
current = self._catalog[plugin_id]["install_count"]
|
||||
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.install_count = max(0, row.install_count - 1)
|
||||
await db.commit()
|
||||
|
||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||
|
||||
def _get_pending_entries(self) -> list[dict[str, Any]]:
|
||||
"""Return all entries with status='pending_review' (synchronous helper)."""
|
||||
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
|
||||
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||
"""Return all entries with status='pending_review'."""
|
||||
result = await db.execute(
|
||||
select(Plugin).where(Plugin.status == "pending_review")
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"manifest": _plugin_to_manifest(r),
|
||||
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Plugin review workflow.
|
||||
"""Plugin review workflow backed by PostgreSQL.
|
||||
|
||||
Manages the approval queue for newly submitted plugins and enforces a
|
||||
security checklist before any plugin is made visible in the marketplace.
|
||||
@@ -11,10 +11,12 @@ Module-level singleton::
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.models import PluginReview as PluginReviewModel
|
||||
from app.schemas import PluginManifest
|
||||
|
||||
# ── Security policy ───────────────────────────────────────────────────
|
||||
@@ -72,20 +74,16 @@ def validate_manifest(manifest: PluginManifest) -> None:
|
||||
class ReviewQueue:
|
||||
"""Approval queue for pending plugin submissions.
|
||||
|
||||
Delegates status changes to the shared ``PluginRegistry`` singleton so
|
||||
there is a single source of truth for plugin state.
|
||||
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||
Review records are persisted in the ``plugin_reviews`` table.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Completed reviews — Step 12 stores in plugin_reviews table
|
||||
self._reviews: list[dict[str, Any]] = []
|
||||
|
||||
async def get_pending(self) -> list[dict[str, Any]]:
|
||||
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||
"""Return all plugins currently awaiting review.
|
||||
|
||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||
"""
|
||||
entries = registry._get_pending_entries()
|
||||
entries = await registry.get_pending_entries(db)
|
||||
return [
|
||||
{
|
||||
"plugin_id": e["manifest"].id,
|
||||
@@ -97,6 +95,7 @@ class ReviewQueue:
|
||||
|
||||
async def submit_review(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
plugin_id: str,
|
||||
reviewer_id: str,
|
||||
decision: Literal["approved", "rejected"],
|
||||
@@ -108,19 +107,18 @@ class ReviewQueue:
|
||||
``KeyError`` if *plugin_id* is not found in the registry.
|
||||
"""
|
||||
if decision == "approved":
|
||||
await registry.approve_plugin(plugin_id)
|
||||
await registry.approve_plugin(db, plugin_id)
|
||||
else:
|
||||
await registry.reject_plugin(plugin_id, reason=notes)
|
||||
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||
|
||||
self._reviews.append(
|
||||
{
|
||||
"plugin_id": plugin_id,
|
||||
"reviewer_id": reviewer_id,
|
||||
"decision": decision,
|
||||
"notes": notes,
|
||||
"reviewed_at": int(time.time()),
|
||||
}
|
||||
review = PluginReviewModel(
|
||||
plugin_id=plugin_id,
|
||||
reviewer_id=reviewer_id,
|
||||
decision=decision,
|
||||
notes=notes,
|
||||
)
|
||||
db.add(review)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Revenue share tracking and Stripe Connect payouts.
|
||||
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||
|
||||
Records every plugin installation as a revenue event and facilitates
|
||||
70 % / 30 % payouts to developers via Stripe Connect. Storage is
|
||||
in-memory until Step 12 migrates to the ``revenue_events`` table.
|
||||
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||
in the ``revenue_events`` table.
|
||||
|
||||
Module-level singleton::
|
||||
|
||||
@@ -12,13 +12,16 @@ Module-level singleton::
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from sqlalchemy import extract, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.marketplace.plugin_registry import registry
|
||||
from app.models import Plugin, RevenueEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,10 +38,6 @@ class RevenueShare:
|
||||
is not configured, consistent with the rest of the billing layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Step 12 replaces with revenue_events DB table
|
||||
self._events: list[dict[str, Any]] = []
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
@@ -54,6 +53,7 @@ class RevenueShare:
|
||||
|
||||
async def record_install(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
plugin_id: str,
|
||||
user_id: str,
|
||||
amount_cents: int,
|
||||
@@ -72,11 +72,12 @@ class RevenueShare:
|
||||
stripe_transfer_id: str | None = None
|
||||
|
||||
if amount_cents > 0 and self._stripe_configured():
|
||||
plugin_entry = registry._catalog.get(plugin_id)
|
||||
# Look up the plugin's author Stripe account from the DB
|
||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
plugin_row = result.scalar_one_or_none()
|
||||
developer_stripe_account: str | None = None
|
||||
if plugin_entry:
|
||||
# Step 12: look up developer's Stripe account from DB
|
||||
# For now, the author field is used as a placeholder key.
|
||||
if plugin_row and plugin_row.author_id:
|
||||
# Future: look up user.stripe_connect_account_id
|
||||
developer_stripe_account = None # no real account yet
|
||||
|
||||
if developer_stripe_account:
|
||||
@@ -103,22 +104,21 @@ class RevenueShare:
|
||||
plugin_id,
|
||||
)
|
||||
|
||||
self._events.append(
|
||||
{
|
||||
"plugin_id": plugin_id,
|
||||
"user_id": user_id,
|
||||
"amount_cents": amount_cents,
|
||||
"developer_share_cents": developer_share_cents,
|
||||
"stripe_transfer_id": stripe_transfer_id,
|
||||
"paid_at": None,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
event = RevenueEvent(
|
||||
plugin_id=plugin_id,
|
||||
user_id=user_id,
|
||||
amount_cents=amount_cents,
|
||||
developer_share_cents=developer_share_cents,
|
||||
stripe_transfer_id=stripe_transfer_id,
|
||||
)
|
||||
db.add(event)
|
||||
await db.commit()
|
||||
|
||||
await registry.record_install(plugin_id)
|
||||
await registry.record_install(db, plugin_id)
|
||||
|
||||
async def get_earnings(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
developer_id: str,
|
||||
period: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
@@ -136,54 +136,81 @@ class RevenueShare:
|
||||
"developer_share_cents": int,
|
||||
}
|
||||
"""
|
||||
# Find plugin ids belonging to this developer
|
||||
developer_plugin_ids: set[str] = {
|
||||
pid
|
||||
for pid, entry in registry._catalog.items()
|
||||
if entry["manifest"].author == developer_id
|
||||
}
|
||||
# Find plugin ids belonging to this developer (by author_name match)
|
||||
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||
plugin_result = await db.execute(plugin_q)
|
||||
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||
|
||||
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
|
||||
if not developer_plugin_ids:
|
||||
return {
|
||||
"developer_id": developer_id,
|
||||
"period": period,
|
||||
"total_installs": 0,
|
||||
"total_revenue_cents": 0,
|
||||
"developer_share_cents": 0,
|
||||
}
|
||||
|
||||
query = select(
|
||||
func.count().label("total_installs"),
|
||||
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||
|
||||
if period:
|
||||
# Filter by YYYY-MM prefix of the created_at timestamp
|
||||
events = [
|
||||
e
|
||||
for e in events
|
||||
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
||||
]
|
||||
# Filter by YYYY-MM: extract year and month from created_at
|
||||
try:
|
||||
year, month = period.split("-")
|
||||
query = query.where(
|
||||
extract("year", RevenueEvent.created_at) == int(year),
|
||||
extract("month", RevenueEvent.created_at) == int(month),
|
||||
)
|
||||
except ValueError:
|
||||
pass # invalid period format — return all
|
||||
|
||||
result = await db.execute(query)
|
||||
row = result.one()
|
||||
|
||||
return {
|
||||
"developer_id": developer_id,
|
||||
"period": period,
|
||||
"total_installs": len(events),
|
||||
"total_revenue_cents": sum(e["amount_cents"] for e in events),
|
||||
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
|
||||
"total_installs": row.total_installs,
|
||||
"total_revenue_cents": row.total_revenue,
|
||||
"developer_share_cents": row.dev_share,
|
||||
}
|
||||
|
||||
async def payout_developer(self, plugin_id: str, period: str) -> None:
|
||||
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||
|
||||
Marks processed events with ``paid_at`` timestamp.
|
||||
Stubs gracefully when Stripe is not configured.
|
||||
"""
|
||||
unpaid = [
|
||||
e
|
||||
for e in self._events
|
||||
if e["plugin_id"] == plugin_id
|
||||
and e["paid_at"] is None
|
||||
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
||||
]
|
||||
try:
|
||||
year, month = period.split("-")
|
||||
year_int, month_int = int(year), int(month)
|
||||
except ValueError:
|
||||
logger.warning("Invalid period format: %s", period)
|
||||
return
|
||||
|
||||
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
|
||||
result = await db.execute(
|
||||
select(RevenueEvent).where(
|
||||
RevenueEvent.plugin_id == plugin_id,
|
||||
RevenueEvent.paid_at.is_(None),
|
||||
extract("year", RevenueEvent.created_at) == year_int,
|
||||
extract("month", RevenueEvent.created_at) == month_int,
|
||||
)
|
||||
)
|
||||
unpaid = list(result.scalars().all())
|
||||
|
||||
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||
if total_dev_share <= 0 or not unpaid:
|
||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||
return
|
||||
|
||||
if self._stripe_configured():
|
||||
plugin_entry = registry._catalog.get(plugin_id)
|
||||
developer_stripe_account: str | None = None # Step 12: fetch from DB
|
||||
if plugin_entry and developer_stripe_account:
|
||||
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
plugin_row = plugin_result.scalar_one_or_none()
|
||||
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||
if plugin_row and developer_stripe_account:
|
||||
try:
|
||||
s = self._stripe()
|
||||
s.Transfer.create(
|
||||
@@ -196,9 +223,10 @@ class RevenueShare:
|
||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||
return
|
||||
|
||||
paid_ts = int(time.time())
|
||||
paid_ts = datetime.now(timezone.utc)
|
||||
for event in unpaid:
|
||||
event["paid_at"] = paid_ts
|
||||
event.paid_at = paid_ts
|
||||
await db.commit()
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
|
||||
@@ -32,9 +32,9 @@ from sqlalchemy import (
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
Uuid,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db import Base
|
||||
@@ -64,7 +64,7 @@ class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
@@ -89,10 +89,10 @@ class RefreshToken(Base):
|
||||
__tablename__ = "refresh_tokens"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
@@ -107,10 +107,10 @@ class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False, unique=True, index=True
|
||||
)
|
||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||
@@ -128,10 +128,10 @@ class StorageRecord(Base):
|
||||
__tablename__ = "storage_records"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
@@ -149,10 +149,10 @@ class BackupMetadata(Base):
|
||||
__tablename__ = "backup_metadata"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
@@ -173,7 +173,7 @@ class Plugin(Base):
|
||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||
# nullable until developer account system is built
|
||||
author_id: Mapped[str | None] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||
@@ -207,13 +207,13 @@ class PluginInstallation(Base):
|
||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
installed_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
@@ -226,13 +226,13 @@ class PluginReview(Base):
|
||||
__tablename__ = "plugin_reviews"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
reviewer_id: Mapped[str | None] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
@@ -250,13 +250,13 @@ class RevenueEvent(Base):
|
||||
__tablename__ = "revenue_events"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), primary_key=True, default=_uuid
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
plugin_id: Mapped[str] = mapped_column(
|
||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
@@ -15,8 +15,10 @@ bcrypt>=4.2.0
|
||||
python-dotenv>=1.0.0
|
||||
httpx>=0.28.0
|
||||
websockets>=14.0
|
||||
psycopg2-binary>=2.9.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.24.0
|
||||
aiosqlite>=0.20.0
|
||||
moto[s3]>=5.0.0
|
||||
pinecone>=5.0.0
|
||||
qdrant-client>=1.7.0
|
||||
|
||||
208
tests/conftest.py
Normal file
208
tests/conftest.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Shared test fixtures for database-backed tests.
|
||||
|
||||
Provides an async SQLite in-memory engine that auto-creates all tables,
|
||||
a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
from sqlalchemy import StaticPool, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import Base, get_session
|
||||
from app.main import app
|
||||
from app.models import Plugin, Subscription, User
|
||||
|
||||
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||
|
||||
TEST_USER_IDS: dict[str, str] = {
|
||||
"free": "00000000-0000-0000-0000-000000000001",
|
||||
"pro": "00000000-0000-0000-0000-000000000002",
|
||||
"power": "00000000-0000-0000-0000-000000000003",
|
||||
"team": "00000000-0000-0000-0000-000000000004",
|
||||
}
|
||||
|
||||
# ── Async SQLite engine ──────────────────────────────────────────────
|
||||
|
||||
_TEST_ENGINE = create_async_engine(
|
||||
"sqlite+aiosqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
_TestSessionLocal = async_sessionmaker(
|
||||
_TEST_ENGINE,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
|
||||
# Enable foreign key enforcement for SQLite (off by default).
|
||||
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
|
||||
def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001
|
||||
cursor = dbapi_conn.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def _create_tables():
|
||||
"""Create all tables before each test, seed test users, then drop after."""
|
||||
async with _TEST_ENGINE.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
# Seed one User + Subscription per tier so FK constraints and auth work.
|
||||
async with _TestSessionLocal() as session:
|
||||
for tier, uid in TEST_USER_IDS.items():
|
||||
session.add(User(
|
||||
id=uid,
|
||||
email=f"{tier}@test.com",
|
||||
password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
|
||||
tier=tier,
|
||||
))
|
||||
session.add(Subscription(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=uid,
|
||||
tier=tier,
|
||||
stripe_subscription_id=f"sub_test_{tier}",
|
||||
status="active",
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
yield
|
||||
async with _TEST_ENGINE.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Yield a per-test async DB session."""
|
||||
async with _TestSessionLocal() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001
|
||||
"""FastAPI test client with ``get_session`` overridden to use the test DB."""
|
||||
|
||||
async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _override_get_session
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
# ── Seed data helpers ────────────────────────────────────────────────
|
||||
|
||||
_SEED_PLUGINS = [
|
||||
Plugin(
|
||||
id="plugin-github-sync",
|
||||
name="GitHub Sync",
|
||||
description="Sync tasks with GitHub Issues and pull requests.",
|
||||
version="1.0.0",
|
||||
author_name="Adiuva",
|
||||
category="productivity",
|
||||
price_cents=0,
|
||||
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||
status="approved",
|
||||
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
||||
install_count=0,
|
||||
avg_rating=0.0,
|
||||
),
|
||||
Plugin(
|
||||
id="plugin-slack-notify",
|
||||
name="Slack Notifier",
|
||||
description="Post task and checkpoint updates to Slack channels.",
|
||||
version="1.2.0",
|
||||
author_name="Adiuva",
|
||||
category="communication",
|
||||
price_cents=499,
|
||||
permissions=json.dumps(["read:tasks", "read:checkpoints"]),
|
||||
status="approved",
|
||||
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||
install_count=0,
|
||||
avg_rating=0.0,
|
||||
),
|
||||
Plugin(
|
||||
id="plugin-time-tracker",
|
||||
name="Time Tracker",
|
||||
description="Track time spent on tasks with automatic reporting.",
|
||||
version="0.9.1",
|
||||
author_name="Third Party",
|
||||
category="productivity",
|
||||
price_cents=999,
|
||||
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||
status="approved",
|
||||
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||
install_count=0,
|
||||
avg_rating=0.0,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
||||
"""Insert the 3 default approved plugins and return them."""
|
||||
plugins = []
|
||||
for template in _SEED_PLUGINS:
|
||||
p = Plugin(
|
||||
id=template.id,
|
||||
name=template.name,
|
||||
description=template.description,
|
||||
version=template.version,
|
||||
author_name=template.author_name,
|
||||
category=template.category,
|
||||
price_cents=template.price_cents,
|
||||
permissions=template.permissions,
|
||||
status=template.status,
|
||||
s3_package_key=template.s3_package_key,
|
||||
install_count=template.install_count,
|
||||
avg_rating=template.avg_rating,
|
||||
)
|
||||
db_session.add(p)
|
||||
plugins.append(p)
|
||||
await db_session.commit()
|
||||
return plugins
|
||||
|
||||
|
||||
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def make_jwt(
|
||||
tier: str = "power",
|
||||
user_id: str | None = None,
|
||||
email: str | None = None,
|
||||
) -> str:
|
||||
"""Create a signed test JWT.
|
||||
|
||||
Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can
|
||||
find the corresponding ``Subscription`` row in the test database.
|
||||
"""
|
||||
uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4()))
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": uid,
|
||||
"email": email or f"{tier}@test.com",
|
||||
"tier": tier,
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
}
|
||||
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]:
|
||||
"""Return an Authorization header dict for the given tier."""
|
||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||
@@ -18,13 +18,30 @@ from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.schemas import ChatResponse
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Autouse: redirect all DB access to the in-memory SQLite test engine.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
"""Route all get_session calls to the test SQLite session."""
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
_CHAT_BODY = {
|
||||
"message": "hello",
|
||||
"context": {
|
||||
@@ -74,14 +91,15 @@ class TestAuthMiddleware:
|
||||
"""Tests exercised via GET /api/v1/auth/me."""
|
||||
|
||||
def test_valid_token_returns_profile(self) -> None:
|
||||
uid = str(uuid.uuid4())
|
||||
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
|
||||
# Use the seeded pro user so the subscription lookup returns 'pro'.
|
||||
uid = TEST_USER_IDS["pro"]
|
||||
token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro")
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == uid
|
||||
assert data["email"] == "alice@example.com"
|
||||
assert data["email"] == "pro@test.com"
|
||||
assert data["tier"] == "pro"
|
||||
|
||||
def test_missing_token_returns_401(self) -> None:
|
||||
|
||||
@@ -1,52 +1,34 @@
|
||||
"""Tests for Step 10: Plugin Marketplace.
|
||||
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
|
||||
|
||||
Covers:
|
||||
- PluginRegistry: catalog management, filtering, sorting, install counts
|
||||
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
|
||||
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||
- RevenueShare: install event recording, earnings aggregation
|
||||
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
|
||||
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
from jose import jwt
|
||||
from unittest.mock import patch
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.main import app
|
||||
from app.marketplace.plugin_registry import PluginRegistry
|
||||
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||
from app.marketplace.revenue_share import RevenueShare
|
||||
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
|
||||
from app.schemas import PluginManifest
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_jwt(tier: str = "power", user_id: str | None = None) -> str:
|
||||
uid = user_id or str(uuid.uuid4())
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"sub": uid,
|
||||
"email": f"{uid[:8]}@example.com",
|
||||
"tier": tier,
|
||||
"exp": now + 3600,
|
||||
"iat": now,
|
||||
}
|
||||
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
|
||||
|
||||
def _auth(tier: str = "power") -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {_make_jwt(tier)}"}
|
||||
|
||||
|
||||
def _fresh_manifest(
|
||||
plugin_id: str | None = None,
|
||||
category: str = "productivity",
|
||||
@@ -67,118 +49,150 @@ def _fresh_manifest(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginRegistry
|
||||
# PluginRegistry (DB-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginRegistry:
|
||||
"""Each test uses a fresh PluginRegistry instance to avoid catalog pollution."""
|
||||
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
|
||||
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None:
|
||||
result = await reg.list_plugins()
|
||||
async def test_seed_plugins_are_listed(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session)
|
||||
assert result.total == 3
|
||||
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_approved_only(self, reg: PluginRegistry) -> None:
|
||||
async def test_list_approved_only(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "plugins/key.zip")
|
||||
result = await reg.list_plugins()
|
||||
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
|
||||
result = await reg.list_plugins(db_session)
|
||||
ids = [p.id for p in result.plugins]
|
||||
assert manifest.id not in ids # still pending
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_filter_by_category(self, reg: PluginRegistry) -> None:
|
||||
result = await reg.list_plugins(category="communication")
|
||||
async def test_list_filter_by_category(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session, category="communication")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-slack-notify"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_filter_by_query(self, reg: PluginRegistry) -> None:
|
||||
result = await reg.list_plugins(query="time")
|
||||
async def test_list_filter_by_query(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
result = await reg.list_plugins(db_session, query="time")
|
||||
assert result.total == 1
|
||||
assert result.plugins[0].id == "plugin-time-tracker"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_install("plugin-slack-notify")
|
||||
await reg.record_install("plugin-slack-notify")
|
||||
result = await reg.list_plugins(sort="installs")
|
||||
async def test_list_sort_by_installs(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_install(db_session, "plugin-slack-notify")
|
||||
await reg.record_install(db_session, "plugin-slack-notify")
|
||||
result = await reg.list_plugins(db_session, sort="installs")
|
||||
assert result.plugins[0].id == "plugin-slack-notify"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_found(self, reg: PluginRegistry) -> None:
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
async def test_get_plugin_found(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["manifest"].id == "plugin-github-sync"
|
||||
assert "install_count" in entry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None:
|
||||
entry = await reg.get_plugin("no-such-plugin")
|
||||
async def test_get_plugin_not_found(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
entry = await reg.get_plugin(db_session, "no-such-plugin")
|
||||
assert entry is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_sets_pending(self, reg: PluginRegistry) -> None:
|
||||
async def test_submit_sets_pending(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
plugin_id = await reg.submit_plugin(manifest, "key.zip")
|
||||
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
assert plugin_id == manifest.id
|
||||
assert reg._catalog[plugin_id]["status"] == "pending_review"
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "pending_review"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_makes_visible(self, reg: PluginRegistry) -> None:
|
||||
async def test_approve_makes_visible(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await reg.approve_plugin(manifest.id)
|
||||
result = await reg.list_plugins()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await reg.approve_plugin(db_session, manifest.id)
|
||||
result = await reg.list_plugins(db_session)
|
||||
assert manifest.id in [p.id for p in result.plugins]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_stores_reason(self, reg: PluginRegistry) -> None:
|
||||
async def test_reject_stores_reason(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await reg.reject_plugin(manifest.id, reason="Unsafe permissions")
|
||||
assert reg._catalog[manifest.id]["status"] == "rejected"
|
||||
assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions"
|
||||
result = await reg.list_plugins()
|
||||
assert manifest.id not in [p.id for p in result.plugins]
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "rejected"
|
||||
assert row.rejection_reason == "Unsafe permissions"
|
||||
listed = await reg.list_plugins(db_session)
|
||||
assert manifest.id not in [p.id for p in listed.plugins]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None:
|
||||
async def test_approve_unknown_raises_key_error(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession
|
||||
) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await reg.approve_plugin("ghost-plugin")
|
||||
await reg.approve_plugin(db_session, "ghost-plugin")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_increments_count(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_install("plugin-github-sync")
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
async def test_record_install_increments_count(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_install(db_session, "plugin-github-sync")
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_install("plugin-github-sync")
|
||||
await reg.record_install("plugin-github-sync")
|
||||
await reg.record_uninstall("plugin-github-sync")
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
async def test_record_uninstall_decrements_count(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_install(db_session, "plugin-github-sync")
|
||||
await reg.record_install(db_session, "plugin-github-sync")
|
||||
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None:
|
||||
await reg.record_uninstall("plugin-github-sync") # already 0
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
async def test_record_uninstall_floors_at_zero(
|
||||
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReviewQueue
|
||||
# ReviewQueue (DB-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -188,37 +202,47 @@ class TestReviewQueue:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def queue(self, reg: PluginRegistry) -> ReviewQueue:
|
||||
# Patch the 'registry' name as bound inside plugin_review.py
|
||||
with patch("app.marketplace.plugin_review.registry", reg):
|
||||
yield ReviewQueue()
|
||||
def queue(self) -> ReviewQueue:
|
||||
return ReviewQueue()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_returns_submitted_plugins(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue
|
||||
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
pending = await queue.get_pending()
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
pending = await queue.get_pending(db_session)
|
||||
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_review_approved(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue
|
||||
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good")
|
||||
assert reg._catalog[manifest.id]["status"] == "approved"
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "approved"
|
||||
# Check review row was persisted
|
||||
review_result = await db_session.execute(
|
||||
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
|
||||
)
|
||||
review = review_result.scalar_one()
|
||||
assert review.decision == "approved"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_submit_review_rejected(
|
||||
self, reg: PluginRegistry, queue: ReviewQueue
|
||||
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||
) -> None:
|
||||
manifest = _fresh_manifest()
|
||||
await reg.submit_plugin(manifest, "key.zip")
|
||||
await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions")
|
||||
assert reg._catalog[manifest.id]["status"] == "rejected"
|
||||
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||
await queue.submit_review(
|
||||
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
|
||||
)
|
||||
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||
row = result.scalar_one()
|
||||
assert row.status == "rejected"
|
||||
|
||||
def test_validate_manifest_ok(self) -> None:
|
||||
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||
@@ -241,65 +265,66 @@ class TestReviewQueue:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RevenueShare
|
||||
# RevenueShare (DB-backed)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRevenueShare:
|
||||
@pytest.fixture
|
||||
def reg(self) -> PluginRegistry:
|
||||
return PluginRegistry()
|
||||
|
||||
@pytest.fixture
|
||||
def rs(self, reg: PluginRegistry) -> RevenueShare:
|
||||
# Patch the 'registry' name as bound inside revenue_share.py
|
||||
with patch("app.marketplace.revenue_share.registry", reg):
|
||||
yield RevenueShare()
|
||||
def rs(self) -> RevenueShare:
|
||||
return RevenueShare()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_free_plugin(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
||||
assert len(rs._events) == 1
|
||||
assert rs._events[0]["developer_share_cents"] == 0
|
||||
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||
result = await db_session.execute(
|
||||
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
|
||||
)
|
||||
event = result.scalar_one()
|
||||
assert event.developer_share_cents == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_paid_plugin_no_stripe(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
# No STRIPE_SECRET_KEY configured in test env — should not crash
|
||||
await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499)
|
||||
assert len(rs._events) == 1
|
||||
assert rs._events[0]["amount_cents"] == 499
|
||||
assert rs._events[0]["developer_share_cents"] == int(499 * 0.70)
|
||||
await rs.record_install(
|
||||
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
|
||||
)
|
||||
result = await db_session.execute(
|
||||
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
|
||||
)
|
||||
event = result.scalar_one()
|
||||
assert event.amount_cents == 499
|
||||
assert event.developer_share_cents == int(499 * 0.70)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_install_increments_registry_count(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
||||
entry = await reg.get_plugin("plugin-github-sync")
|
||||
reg = PluginRegistry()
|
||||
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||
assert entry is not None
|
||||
assert entry["install_count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_earnings_empty(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
self, rs: RevenueShare, db_session: AsyncSession
|
||||
) -> None:
|
||||
result = await rs.get_earnings("unknown-dev")
|
||||
result = await rs.get_earnings(db_session, "unknown-dev")
|
||||
assert result["total_installs"] == 0
|
||||
assert result["total_revenue_cents"] == 0
|
||||
assert result["developer_share_cents"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_earnings_aggregates(
|
||||
self, reg: PluginRegistry, rs: RevenueShare
|
||||
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||
) -> None:
|
||||
# "Adiuva" is the author of the seeded plugins
|
||||
await rs.record_install("plugin-slack-notify", "u1", amount_cents=499)
|
||||
await rs.record_install("plugin-slack-notify", "u2", amount_cents=499)
|
||||
result = await rs.get_earnings("Adiuva")
|
||||
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
|
||||
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
|
||||
result = await rs.get_earnings(db_session, "Adiuva")
|
||||
assert result["total_installs"] == 2
|
||||
assert result["total_revenue_cents"] == 998
|
||||
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||
@@ -311,77 +336,67 @@ class TestRevenueShare:
|
||||
|
||||
|
||||
class TestPluginRoutes:
|
||||
def test_list_plugins_requires_power_tier(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("free"))
|
||||
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_list_plugins_pro_tier_blocked(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("pro"))
|
||||
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
|
||||
assert resp.status_code == 403
|
||||
|
||||
def test_list_plugins_power_tier_ok(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("power"))
|
||||
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "plugins" in data
|
||||
assert data["total"] >= 3
|
||||
assert data["total"] == 3
|
||||
|
||||
def test_list_plugins_team_tier_ok(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins", headers=_auth("team"))
|
||||
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
|
||||
assert resp.status_code == 200
|
||||
|
||||
def test_get_plugin_found(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth())
|
||||
def test_get_plugin_found(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||
assert "install_count" in data
|
||||
|
||||
def test_get_plugin_not_found(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth())
|
||||
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
|
||||
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_install_plugin_free(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=_auth(),
|
||||
)
|
||||
def test_install_plugin_free(self, client, seed_plugins) -> None:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=auth_header(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert "download_url" in data
|
||||
|
||||
def test_install_plugin_not_found(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/ghost/install",
|
||||
json={"plugin_id": "ghost"},
|
||||
headers=_auth(),
|
||||
)
|
||||
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/ghost/install",
|
||||
json={"plugin_id": "ghost"},
|
||||
headers=auth_header(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_uninstall_plugin_ok(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.delete(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
headers=_auth(),
|
||||
)
|
||||
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
|
||||
resp = client.delete(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
headers=auth_header(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["ok"] is True
|
||||
|
||||
def test_install_requires_power_tier(self) -> None:
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=_auth("free"),
|
||||
)
|
||||
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
|
||||
resp = client.post(
|
||||
"/api/v1/plugins/plugin-github-sync/install",
|
||||
json={"plugin_id": "plugin-github-sync"},
|
||||
headers=auth_header("free"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
Reference in New Issue
Block a user