From d0b303e745c3e5dbe1f6f1a51350fd99ab510aaa Mon Sep 17 00:00:00 2001 From: roberto Date: Tue, 3 Mar 2026 14:53:34 +0100 Subject: [PATCH] Step 12 - completed --- BACKEND_PLAN.md | 6 +- alembic/versions/002_seed_plugins.py | 92 ++++++++ app/api/routes/backup.py | 113 +++++---- app/api/routes/plugins.py | 60 ++++- app/api/routes/storage.py | 132 ++++++----- app/marketplace/plugin_registry.py | 253 ++++++++++---------- app/marketplace/plugin_review.py | 38 ++- app/marketplace/revenue_share.py | 134 ++++++----- app/models.py | 34 +-- requirements.txt | 2 + tests/conftest.py | 208 ++++++++++++++++ tests/test_middleware.py | 24 +- tests/test_plugins.py | 341 ++++++++++++++------------- 13 files changed, 950 insertions(+), 487 deletions(-) create mode 100644 alembic/versions/002_seed_plugins.py create mode 100644 tests/conftest.py diff --git a/BACKEND_PLAN.md b/BACKEND_PLAN.md index b450f98..bc37989 100644 --- a/BACKEND_PLAN.md +++ b/BACKEND_PLAN.md @@ -439,7 +439,7 @@ adiuva-api/ - **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat). ### Step 12 — Database (auth/billing/marketplace only) -- [ ] PostgreSQL schema via Alembic: +- [x] PostgreSQL schema via Alembic: - `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at` - `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at` - `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at` @@ -449,8 +449,8 @@ adiuva-api/ - `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at` - `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at` - `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at` -- [ ] Initial Alembic migration -- [ ] SQLAlchemy models in `app/models.py` +- [x] Initial Alembic migration +- [x] SQLAlchemy models in `app/models.py` - **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext. ### Step 13 — Testing & deployment diff --git a/alembic/versions/002_seed_plugins.py b/alembic/versions/002_seed_plugins.py new file mode 100644 index 0000000..0fad36a --- /dev/null +++ b/alembic/versions/002_seed_plugins.py @@ -0,0 +1,92 @@ +"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker. + +Revision ID: 002 +Revises: 001 +Create Date: 2026-03-03 +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "002" +down_revision: Union[str, None] = "001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +_SEED_PLUGINS = [ + { + "id": "plugin-github-sync", + "name": "GitHub Sync", + "description": "Sync tasks with GitHub Issues and pull requests.", + "version": "1.0.0", + "author_name": "Adiuva", + "category": "productivity", + "price_cents": 0, + "permissions": json.dumps(["read:tasks", "write:tasks"]), + "status": "approved", + "s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip", + "install_count": 0, + "avg_rating": 0.0, + }, + { + "id": "plugin-slack-notify", + "name": "Slack Notifier", + "description": "Post task and checkpoint updates to Slack channels.", + "version": "1.2.0", + "author_name": "Adiuva", + "category": "communication", + "price_cents": 499, + "permissions": json.dumps(["read:tasks", "read:checkpoints"]), + "status": "approved", + "s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip", + "install_count": 0, + "avg_rating": 0.0, + }, + { + "id": "plugin-time-tracker", + "name": "Time Tracker", + "description": "Track time spent on tasks with automatic reporting.", + "version": "0.9.1", + "author_name": "Third Party", + "category": "productivity", + "price_cents": 999, + "permissions": json.dumps(["read:tasks", "write:tasks"]), + "status": "approved", + "s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip", + "install_count": 0, + "avg_rating": 0.0, + }, +] + + +def upgrade() -> None: + plugins = sa.table( + "plugins", + sa.column("id", sa.String), + sa.column("name", sa.String), + sa.column("description", sa.Text), + sa.column("version", sa.String), + sa.column("author_name", sa.String), + sa.column("category", sa.String), + sa.column("price_cents", sa.Integer), + sa.column("permissions", sa.Text), + sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")), + sa.column("s3_package_key", sa.String), + sa.column("install_count", sa.Integer), + sa.column("avg_rating", sa.Float), + ) + op.bulk_insert(plugins, _SEED_PLUGINS) + + +def downgrade() -> None: + op.execute( + "DELETE FROM plugins WHERE id IN (" + "'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'" + ")" + ) diff --git a/app/api/routes/backup.py b/app/api/routes/backup.py index bb8821a..2b8eeae 100644 --- a/app/api/routes/backup.py +++ b/app/api/routes/backup.py @@ -1,7 +1,7 @@ """Backup routes: upload, download, history, and delete E2E-encrypted backups. -Blobs are stored in S3 via BlobStore. Backup metadata is kept in an -in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table). +Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the +PostgreSQL ``backup_metadata`` table. IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI treating "history" as a ``{backup_id}`` path parameter. @@ -9,14 +9,17 @@ treating "history" as a ``{backup_id}`` path parameter. from __future__ import annotations -import time +import uuid from email.utils import parsedate_to_datetime -from typing import Any from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.billing.tier_manager import tier_manager +from app.db import get_session +from app.models import BackupMetadata as BackupMetadataModel from app.schemas import BackupMetadata, UserProfile from app.storage.blob_store import BlobStore from app.storage.encryption import reject_if_tampered @@ -25,14 +28,25 @@ router = APIRouter(prefix="/backup", tags=["backup"]) _blob_store = BlobStore() -# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12 -_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records + +async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int: + """Return total backup bytes stored by *user_id*.""" + result = await db.execute( + select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where( + BackupMetadataModel.user_id == user_id + ) + ) + return int(result.scalar_one()) -def _check_backup_quota(user_id: str, size_bytes: int) -> None: +async def _check_backup_quota( + user: UserProfile, size_bytes: int, db: AsyncSession +) -> None: """Raise HTTP 402 if the upload would exceed the tier's backup limit.""" - current = sum(b["size_bytes"] for b in _backups.get(user_id, [])) - tier_manager.enforce_backup_quota(user_id, current_bytes=current, additional_bytes=size_bytes) + current = await _current_backup_bytes(user.id, db) + tier_manager.enforce_backup_quota( + user.tier, current_bytes=current, additional_bytes=size_bytes + ) @router.put("") @@ -42,6 +56,7 @@ async def upload_backup( x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"), x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"), current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Upload an E2E-encrypted backup blob. @@ -49,24 +64,23 @@ async def upload_backup( """ blob = await request.body() reject_if_tampered(blob, x_backup_checksum) - _check_backup_quota(current_user.id, len(blob)) + await _check_backup_quota(current_user, len(blob), db) s3_key = await _blob_store.upload( current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum ) - backup_record: dict[str, Any] = { - "id": str(x_backup_timestamp), - "s3_key": s3_key, - "version": x_backup_version, - "timestamp": x_backup_timestamp, - "checksum": x_backup_checksum, - "size_bytes": len(blob), - } - - user_backups = _backups.setdefault(current_user.id, []) - user_backups.append(backup_record) - user_backups.sort(key=lambda b: b["timestamp"], reverse=True) + row = BackupMetadataModel( + id=str(uuid.uuid4()), + user_id=current_user.id, + s3_key=s3_key, + version=x_backup_version, + timestamp=x_backup_timestamp, + checksum=x_backup_checksum, + size_bytes=len(blob), + ) + db.add(row) + await db.commit() return {"ok": True} @@ -74,16 +88,23 @@ async def upload_backup( @router.get("/history", response_model=list[BackupMetadata]) async def backup_history( current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> list[BackupMetadata]: """Return backup metadata records for the authenticated user (no blob bytes).""" + result = await db.execute( + select(BackupMetadataModel) + .where(BackupMetadataModel.user_id == current_user.id) + .order_by(BackupMetadataModel.timestamp.desc()) + ) + rows = result.scalars().all() return [ BackupMetadata( - version=b["version"], - timestamp=b["timestamp"], - checksum=b["checksum"], - chunk_count=1, # single-chunk uploads for now — TODO(Step12): track real count + version=r.version, + timestamp=r.timestamp, + checksum=r.checksum, + chunk_count=1, ) - for b in _backups.get(current_user.id, []) + for r in rows ] @@ -91,32 +112,37 @@ async def backup_history( async def download_backup( request: Request, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> Response: """Download the latest backup blob. Supports ``If-Modified-Since``.""" - user_backups = _backups.get(current_user.id, []) - if not user_backups: + result = await db.execute( + select(BackupMetadataModel) + .where(BackupMetadataModel.user_id == current_user.id) + .order_by(BackupMetadataModel.timestamp.desc()) + .limit(1) + ) + latest = result.scalar_one_or_none() + if latest is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found") - latest = user_backups[0] - ims_header = request.headers.get("If-Modified-Since") if ims_header: try: ims_dt = parsedate_to_datetime(ims_header) ims_ms = int(ims_dt.timestamp() * 1000) - if latest["timestamp"] <= ims_ms: + if latest.timestamp <= ims_ms: return Response(status_code=status.HTTP_304_NOT_MODIFIED) except Exception: pass # malformed header — ignore and serve the blob - blob = await _blob_store.download(current_user.id, latest["s3_key"]) + blob = await _blob_store.download(current_user.id, latest.s3_key) return Response( content=blob, media_type="application/octet-stream", headers={ - "X-Backup-Version": str(latest["version"]), - "X-Backup-Timestamp": str(latest["timestamp"]), - "X-Checksum": latest["checksum"], + "X-Backup-Version": str(latest.version), + "X-Backup-Timestamp": str(latest.timestamp), + "X-Checksum": latest.checksum, }, ) @@ -125,14 +151,21 @@ async def download_backup( async def delete_backup( backup_id: str, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Delete a specific backup by ID.""" - user_backups = _backups.get(current_user.id, []) - target = next((b for b in user_backups if b["id"] == backup_id), None) + result = await db.execute( + select(BackupMetadataModel).where( + BackupMetadataModel.id == backup_id, + BackupMetadataModel.user_id == current_user.id, + ) + ) + target = result.scalar_one_or_none() if target is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found") - await _blob_store.delete(current_user.id, target["s3_key"]) - _backups[current_user.id] = [b for b in user_backups if b["id"] != backup_id] + await _blob_store.delete(current_user.id, target.s3_key) + await db.delete(target) + await db.commit() return {"ok": True} diff --git a/app/api/routes/plugins.py b/app/api/routes/plugins.py index 899612e..f3a2e6e 100644 --- a/app/api/routes/plugins.py +++ b/app/api/routes/plugins.py @@ -1,8 +1,7 @@ """Plugins routes: browse and install plugins from the marketplace. -Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced -in Step 10. Step 12 will swap those services' in-memory stores for -PostgreSQL persistence. +Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that +persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables. """ from __future__ import annotations @@ -11,10 +10,14 @@ from typing import Any, Literal from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user +from app.db import get_session from app.marketplace.plugin_registry import registry from app.marketplace.revenue_share import revenue_share +from app.models import PluginInstallation, PluginReview as PluginReviewModel from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile router = APIRouter(prefix="/plugins", tags=["plugins"]) @@ -36,7 +39,7 @@ def _require_plugin_tier(user: UserProfile) -> None: class _PluginDetail(BaseModel): plugin: PluginManifest install_count: int - ratings: list[Any] # Step 12 populates from plugin_reviews table + ratings: list[Any] # ── Routes ──────────────────────────────────────────────────────────── @@ -48,26 +51,44 @@ async def list_plugins( page: int = Query(default=1, ge=1), sort: Literal["rating", "installs", "newest"] = Query(default="newest"), current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> PluginListResponse: """Browse the plugin marketplace. Requires Power tier or above.""" _require_plugin_tier(current_user) - return await registry.list_plugins(category=category, query=q, page=page, sort=sort) + return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort) @router.get("/{plugin_id}", response_model=_PluginDetail) async def get_plugin( plugin_id: str, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> _PluginDetail: """Get full plugin details including install count. Requires Power tier or above.""" _require_plugin_tier(current_user) - entry = await registry.get_plugin(plugin_id) + entry = await registry.get_plugin(db, plugin_id) if entry is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found") + + # Fetch review ratings for this plugin + review_result = await db.execute( + select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id) + ) + reviews = review_result.scalars().all() + ratings = [ + { + "reviewer_id": r.reviewer_id, + "decision": r.decision, + "notes": r.notes, + "reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None, + } + for r in reviews + ] + return _PluginDetail( plugin=entry["manifest"], install_count=entry["install_count"], - ratings=[], # Step 12 populates from plugin_reviews table + ratings=ratings, ) @@ -76,17 +97,27 @@ async def install_plugin( plugin_id: str, body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, Any]: """Install a plugin. Triggers Stripe Connect revenue split for paid plugins. Requires Power tier or above. """ _require_plugin_tier(current_user) - entry = await registry.get_plugin(plugin_id) + entry = await registry.get_plugin(db, plugin_id) if entry is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found") + # Record the installation in plugin_installations + installation = PluginInstallation( + plugin_id=plugin_id, + user_id=current_user.id, + ) + db.add(installation) + await db.flush() + await revenue_share.record_install( + db, plugin_id=plugin_id, user_id=current_user.id, amount_cents=entry["manifest"].price_cents, @@ -100,7 +131,18 @@ async def install_plugin( async def uninstall_plugin( plugin_id: str, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Unregister a plugin installation.""" - await registry.record_uninstall(plugin_id) + result = await db.execute( + select(PluginInstallation).where( + PluginInstallation.plugin_id == plugin_id, + PluginInstallation.user_id == current_user.id, + ) + ) + installation = result.scalar_one_or_none() + if installation is not None: + await db.delete(installation) + await db.commit() + await registry.record_uninstall(db, plugin_id) return {"ok": True} diff --git a/app/api/routes/storage.py b/app/api/routes/storage.py index beb5747..d7f8864 100644 --- a/app/api/routes/storage.py +++ b/app/api/routes/storage.py @@ -1,20 +1,23 @@ """Storage routes: CRUD for E2E-encrypted cloud records. -Blobs are stored in S3 via BlobStore. Record metadata is kept in an -in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table). +Blobs are stored in S3 via BlobStore. Record metadata is persisted in the +PostgreSQL ``storage_records`` table. """ from __future__ import annotations -import time import uuid from typing import Any from fastapi import APIRouter, Depends, HTTPException, Query, Response, status from pydantic import BaseModel +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.billing.tier_manager import tier_manager +from app.db import get_session +from app.models import StorageRecord from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile from app.storage.blob_store import BlobStore from app.storage.encryption import reject_if_tampered @@ -23,9 +26,6 @@ router = APIRouter(prefix="/storage", tags=["storage"]) _blob_store = BlobStore() -# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12 -_records: dict[str, dict[str, Any]] = {} - # ── Local response schemas ───────────────────────────────────────────── @@ -44,17 +44,34 @@ class _RecordMeta(BaseModel): # ── Helpers ──────────────────────────────────────────────────────────── -def _check_quota(user_id: str, additional_bytes: int) -> None: - """Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit.""" - current = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id) - tier_manager.enforce_quota(user_id, current_bytes=current, additional_bytes=additional_bytes) +async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int: + """Return total bytes stored by *user_id*.""" + result = await db.execute( + select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where( + StorageRecord.user_id == user_id + ) + ) + return int(result.scalar_one()) -def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]: - """Look up a record and verify ownership. Always returns 404 on mismatch +async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None: + """Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit.""" + current = await _current_usage_bytes(user.id, db) + tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes) + + +async def _get_record_for_user( + record_id: str, user_id: str, db: AsyncSession +) -> StorageRecord: + """Look up a record and verify ownership. Returns 404 on mismatch to prevent user enumeration attacks.""" - record = _records.get(record_id) - if record is None or record["user_id"] != user_id: + result = await db.execute( + select(StorageRecord).where( + StorageRecord.id == record_id, StorageRecord.user_id == user_id + ) + ) + record = result.scalar_one_or_none() + if record is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found") return record @@ -65,30 +82,32 @@ def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]: async def create_record( body: StorageRecordCreate, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> _CreateResponse: """Upload a new E2E-encrypted blob. Verifies checksum before storing.""" reject_if_tampered(body.blob, body.checksum) - _check_quota(current_user.id, len(body.blob)) + await _check_quota(current_user, len(body.blob), db) record_id = str(uuid.uuid4()) - now = int(time.time() * 1000) s3_key = await _blob_store.upload( current_user.id, body.table, record_id, body.blob, body.checksum ) - _records[record_id] = { - "id": record_id, - "user_id": current_user.id, - "table": body.table, - "s3_key": s3_key, - "checksum": body.checksum, - "size_bytes": len(body.blob), - "created_at": now, - "updated_at": now, - } + record = StorageRecord( + id=record_id, + user_id=current_user.id, + table_name=body.table, + s3_key=s3_key, + checksum=body.checksum, + size_bytes=len(body.blob), + ) + db.add(record) + await db.commit() + await db.refresh(record) - return _CreateResponse(id=record_id, created_at=now) + created_at_ms = int(record.created_at.timestamp() * 1000) + return _CreateResponse(id=record_id, created_at=created_at_ms) @router.get("/records", response_model=list[_RecordMeta]) @@ -97,23 +116,26 @@ async def list_records( page: int = Query(default=1, ge=1), limit: int = Query(default=50, ge=1, le=200), current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> list[_RecordMeta]: """List record metadata for the authenticated user. Blob bytes are never returned.""" - all_records = [ - r for r in _records.values() - if r["user_id"] == current_user.id and (table is None or r["table"] == table) - ] - start = (page - 1) * limit - page_records = all_records[start : start + limit] + query = select(StorageRecord).where(StorageRecord.user_id == current_user.id) + if table is not None: + query = query.where(StorageRecord.table_name == table) + query = query.offset((page - 1) * limit).limit(limit) + + result = await db.execute(query) + rows = result.scalars().all() + return [ _RecordMeta( - id=r["id"], - table=r["table"], - checksum=r["checksum"], - created_at=r["created_at"], - updated_at=r["updated_at"], + id=r.id, + table=r.table_name, + checksum=r.checksum, + created_at=int(r.created_at.timestamp() * 1000), + updated_at=int(r.updated_at.timestamp() * 1000), ) - for r in page_records + for r in rows ] @@ -121,14 +143,15 @@ async def list_records( async def download_record( record_id: str, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> Response: """Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header.""" - record = _get_record_for_user(record_id, current_user.id) - blob = await _blob_store.download(current_user.id, record["s3_key"]) + record = await _get_record_for_user(record_id, current_user.id, db) + blob = await _blob_store.download(current_user.id, record.s3_key) return Response( content=blob, media_type="application/octet-stream", - headers={"X-Checksum": record["checksum"]}, + headers={"X-Checksum": record.checksum}, ) @@ -137,23 +160,24 @@ async def update_record( record_id: str, body: StorageRecordUpdate, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Replace the blob for an existing record. Verifies checksum before storing.""" - record = _get_record_for_user(record_id, current_user.id) + record = await _get_record_for_user(record_id, current_user.id, db) reject_if_tampered(body.blob, body.checksum) - delta = len(body.blob) - record["size_bytes"] + delta = len(body.blob) - record.size_bytes if delta > 0: - _check_quota(current_user.id, delta) + await _check_quota(current_user, delta, db) s3_key = await _blob_store.upload( - current_user.id, record["table"], record_id, body.blob, body.checksum + current_user.id, record.table_name, record_id, body.blob, body.checksum ) - record["s3_key"] = s3_key - record["checksum"] = body.checksum - record["size_bytes"] = len(body.blob) - record["updated_at"] = int(time.time() * 1000) + record.s3_key = s3_key + record.checksum = body.checksum + record.size_bytes = len(body.blob) + await db.commit() return {"ok": True} @@ -162,9 +186,11 @@ async def update_record( async def delete_record( record_id: str, current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), ) -> dict[str, bool]: """Delete a record and its S3 blob.""" - record = _get_record_for_user(record_id, current_user.id) - await _blob_store.delete(current_user.id, record["s3_key"]) - del _records[record_id] + record = await _get_record_for_user(record_id, current_user.id, db) + await _blob_store.delete(current_user.id, record.s3_key) + await db.delete(record) + await db.commit() return {"ok": True} diff --git a/app/marketplace/plugin_registry.py b/app/marketplace/plugin_registry.py index 239f655..0bc7fbe 100644 --- a/app/marketplace/plugin_registry.py +++ b/app/marketplace/plugin_registry.py @@ -1,8 +1,7 @@ -"""Plugin catalog registry. +"""Plugin catalog registry backed by PostgreSQL. Maintains the authoritative list of plugins, their review status, and -aggregate install counts. Storage is in-memory until Step 12 migrates to -the ``plugins`` PostgreSQL table. +aggregate install counts. All data is persisted in the ``plugins`` table. Module-level singleton:: @@ -11,144 +10,103 @@ Module-level singleton:: from __future__ import annotations -import copy -import time -import uuid +import json from typing import Any, Literal +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models import Plugin from app.schemas import PluginListResponse, PluginManifest -# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ───── - -_SEED_PLUGINS: list[dict[str, Any]] = [ - { - "manifest": PluginManifest( - id="plugin-github-sync", - name="GitHub Sync", - description="Sync tasks with GitHub Issues and pull requests.", - version="1.0.0", - author="Adiuva", - permissions=["read:tasks", "write:tasks"], - category="productivity", - price_cents=0, - ), - "status": "approved", - "s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip", - "install_count": 0, - "avg_rating": 0.0, - "rejection_reason": None, - "submitted_at": int(time.time()), - }, - { - "manifest": PluginManifest( - id="plugin-slack-notify", - name="Slack Notifier", - description="Post task and checkpoint updates to Slack channels.", - version="1.2.0", - author="Adiuva", - permissions=["read:tasks", "read:checkpoints"], - category="communication", - price_cents=499, - ), - "status": "approved", - "s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip", - "install_count": 0, - "avg_rating": 0.0, - "rejection_reason": None, - "submitted_at": int(time.time()), - }, - { - "manifest": PluginManifest( - id="plugin-time-tracker", - name="Time Tracker", - description="Track time spent on tasks with automatic reporting.", - version="0.9.1", - author="Third Party", - permissions=["read:tasks", "write:tasks"], - category="productivity", - price_cents=999, - ), - "status": "approved", - "s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip", - "install_count": 0, - "avg_rating": 0.0, - "rejection_reason": None, - "submitted_at": int(time.time()), - }, -] - _PAGE_SIZE = 20 +def _plugin_to_manifest(p: Plugin) -> PluginManifest: + """Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``.""" + try: + permissions = json.loads(p.permissions) if p.permissions else [] + except (json.JSONDecodeError, TypeError): + permissions = [] + return PluginManifest( + id=p.id, + name=p.name, + description=p.description, + version=p.version, + author=p.author_name, + permissions=permissions, + category=p.category, + price_cents=p.price_cents, + ) + + class PluginRegistry: - """In-process plugin catalog. + """PostgreSQL-backed plugin catalog. - All mutating methods are ``async`` to make the future DB swap transparent - to callers. + All methods accept an ``AsyncSession`` parameter so the calling route + controls the session lifecycle. """ - def __init__(self) -> None: - # plugin_id → entry dict (deep-copied so each instance is independent) - self._catalog: dict[str, dict[str, Any]] = { - e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS - } - # ── Queries ────────────────────────────────────────────────────── async def list_plugins( self, + db: AsyncSession, category: str | None = None, query: str | None = None, page: int = 1, sort: Literal["rating", "installs", "newest"] = "newest", ) -> PluginListResponse: """Return a page of approved plugins, optionally filtered and sorted.""" - entries = [e for e in self._catalog.values() if e["status"] == "approved"] + base = select(Plugin).where(Plugin.status == "approved") if category: - entries = [e for e in entries if e["manifest"].category == category] - + base = base.where(Plugin.category == category) if query: - q_lower = query.lower() - entries = [ - e - for e in entries - if q_lower in e["manifest"].name.lower() - or q_lower in e["manifest"].description.lower() - ] + pattern = f"%{query}%" + base = base.where( + Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern) + ) + # Count + count_q = select(func.count()).select_from(base.subquery()) + total = (await db.execute(count_q)).scalar_one() + + # Sort if sort == "installs": - entries = sorted(entries, key=lambda e: e["install_count"], reverse=True) + base = base.order_by(Plugin.install_count.desc()) elif sort == "rating": - entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True) - # "newest" = catalog insertion order (dict preserves insertion in Python 3.7+) + base = base.order_by(Plugin.avg_rating.desc()) + else: # newest + base = base.order_by(Plugin.created_at.desc()) - total = len(entries) - start = (page - 1) * _PAGE_SIZE - page_entries = entries[start : start + _PAGE_SIZE] + base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE) + rows = (await db.execute(base)).scalars().all() return PluginListResponse( - plugins=[e["manifest"] for e in page_entries], + plugins=[_plugin_to_manifest(r) for r in rows], total=total, page=page, ) - async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None: + async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None: """Return ``{manifest, status, install_count, avg_rating}`` or ``None``.""" - entry = self._catalog.get(plugin_id) - if entry is None: + result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + p = result.scalar_one_or_none() + if p is None: return None return { - "manifest": entry["manifest"], - "status": entry["status"], - "install_count": entry["install_count"], - "avg_rating": entry["avg_rating"], + "manifest": _plugin_to_manifest(p), + "status": p.status, + "install_count": p.install_count, + "avg_rating": p.avg_rating, } # ── Mutations ──────────────────────────────────────────────────── async def submit_plugin( self, + db: AsyncSession, manifest: PluginManifest, package_s3_key: str, ) -> str: @@ -157,54 +115,97 @@ class PluginRegistry: Returns the plugin_id. If a plugin with the same id already exists it is overwritten (re-submission after rejection). """ - plugin_id = manifest.id or str(uuid.uuid4()) - self._catalog[plugin_id] = { - "manifest": manifest, - "status": "pending_review", - "s3_package_key": package_s3_key, - "install_count": 0, - "avg_rating": 0.0, - "rejection_reason": None, - "submitted_at": int(time.time()), - } + plugin_id = manifest.id + existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + row = existing.scalar_one_or_none() + + if row is not None: + row.name = manifest.name + row.description = manifest.description + row.version = manifest.version + row.author_name = manifest.author + row.category = manifest.category + row.price_cents = manifest.price_cents + row.permissions = json.dumps(manifest.permissions) + row.status = "pending_review" + row.s3_package_key = package_s3_key + row.rejection_reason = None + else: + row = Plugin( + id=plugin_id, + name=manifest.name, + description=manifest.description, + version=manifest.version, + author_name=manifest.author, + category=manifest.category, + price_cents=manifest.price_cents, + permissions=json.dumps(manifest.permissions), + status="pending_review", + s3_package_key=package_s3_key, + install_count=0, + avg_rating=0.0, + ) + db.add(row) + await db.commit() return plugin_id - async def approve_plugin(self, plugin_id: str) -> None: + async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None: """Set *plugin_id* status to ``'approved'``. Raises ``KeyError`` if the plugin is not found. """ - if plugin_id not in self._catalog: + result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + row = result.scalar_one_or_none() + if row is None: raise KeyError(f"Plugin not found: {plugin_id}") - self._catalog[plugin_id]["status"] = "approved" - self._catalog[plugin_id]["rejection_reason"] = None + row.status = "approved" + row.rejection_reason = None + await db.commit() - async def reject_plugin(self, plugin_id: str, reason: str) -> None: + async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None: """Set *plugin_id* status to ``'rejected'`` and record the reason. Raises ``KeyError`` if the plugin is not found. """ - if plugin_id not in self._catalog: + result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + row = result.scalar_one_or_none() + if row is None: raise KeyError(f"Plugin not found: {plugin_id}") - self._catalog[plugin_id]["status"] = "rejected" - self._catalog[plugin_id]["rejection_reason"] = reason + row.status = "rejected" + row.rejection_reason = reason + await db.commit() - async def record_install(self, plugin_id: str) -> None: + async def record_install(self, db: AsyncSession, plugin_id: str) -> None: """Increment the install count for *plugin_id* (no-op if not found).""" - if plugin_id in self._catalog: - self._catalog[plugin_id]["install_count"] += 1 + result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + row = result.scalar_one_or_none() + if row is not None: + row.install_count = row.install_count + 1 + await db.commit() - async def record_uninstall(self, plugin_id: str) -> None: + async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None: """Decrement the install count for *plugin_id*, floored at 0.""" - if plugin_id in self._catalog: - current = self._catalog[plugin_id]["install_count"] - self._catalog[plugin_id]["install_count"] = max(0, current - 1) + result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + row = result.scalar_one_or_none() + if row is not None: + row.install_count = max(0, row.install_count - 1) + await db.commit() # ── Internal helpers used by ReviewQueue ───────────────────────── - def _get_pending_entries(self) -> list[dict[str, Any]]: - """Return all entries with status='pending_review' (synchronous helper).""" - return [e for e in self._catalog.values() if e["status"] == "pending_review"] + async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]: + """Return all entries with status='pending_review'.""" + result = await db.execute( + select(Plugin).where(Plugin.status == "pending_review") + ) + rows = result.scalars().all() + return [ + { + "manifest": _plugin_to_manifest(r), + "submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0, + } + for r in rows + ] # Module-level singleton diff --git a/app/marketplace/plugin_review.py b/app/marketplace/plugin_review.py index 3f63bd7..5e4aeec 100644 --- a/app/marketplace/plugin_review.py +++ b/app/marketplace/plugin_review.py @@ -1,4 +1,4 @@ -"""Plugin review workflow. +"""Plugin review workflow backed by PostgreSQL. Manages the approval queue for newly submitted plugins and enforces a security checklist before any plugin is made visible in the marketplace. @@ -11,10 +11,12 @@ Module-level singleton:: from __future__ import annotations import re -import time from typing import Any, Literal +from sqlalchemy.ext.asyncio import AsyncSession + from app.marketplace.plugin_registry import registry +from app.models import PluginReview as PluginReviewModel from app.schemas import PluginManifest # ── Security policy ─────────────────────────────────────────────────── @@ -72,20 +74,16 @@ def validate_manifest(manifest: PluginManifest) -> None: class ReviewQueue: """Approval queue for pending plugin submissions. - Delegates status changes to the shared ``PluginRegistry`` singleton so - there is a single source of truth for plugin state. + Delegates status changes to the shared ``PluginRegistry`` singleton. + Review records are persisted in the ``plugin_reviews`` table. """ - def __init__(self) -> None: - # Completed reviews — Step 12 stores in plugin_reviews table - self._reviews: list[dict[str, Any]] = [] - - async def get_pending(self) -> list[dict[str, Any]]: + async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]: """Return all plugins currently awaiting review. Each item is ``{plugin_id, manifest, submitted_at}``. """ - entries = registry._get_pending_entries() + entries = await registry.get_pending_entries(db) return [ { "plugin_id": e["manifest"].id, @@ -97,6 +95,7 @@ class ReviewQueue: async def submit_review( self, + db: AsyncSession, plugin_id: str, reviewer_id: str, decision: Literal["approved", "rejected"], @@ -108,19 +107,18 @@ class ReviewQueue: ``KeyError`` if *plugin_id* is not found in the registry. """ if decision == "approved": - await registry.approve_plugin(plugin_id) + await registry.approve_plugin(db, plugin_id) else: - await registry.reject_plugin(plugin_id, reason=notes) + await registry.reject_plugin(db, plugin_id, reason=notes) - self._reviews.append( - { - "plugin_id": plugin_id, - "reviewer_id": reviewer_id, - "decision": decision, - "notes": notes, - "reviewed_at": int(time.time()), - } + review = PluginReviewModel( + plugin_id=plugin_id, + reviewer_id=reviewer_id, + decision=decision, + notes=notes, ) + db.add(review) + await db.commit() # Module-level singleton diff --git a/app/marketplace/revenue_share.py b/app/marketplace/revenue_share.py index 4c8c1dd..05f1d9f 100644 --- a/app/marketplace/revenue_share.py +++ b/app/marketplace/revenue_share.py @@ -1,8 +1,8 @@ -"""Revenue share tracking and Stripe Connect payouts. +"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL. Records every plugin installation as a revenue event and facilitates -70 % / 30 % payouts to developers via Stripe Connect. Storage is -in-memory until Step 12 migrates to the ``revenue_events`` table. +70 % / 30 % payouts to developers via Stripe Connect. Data is persisted +in the ``revenue_events`` table. Module-level singleton:: @@ -12,13 +12,16 @@ Module-level singleton:: from __future__ import annotations import logging -import time +from datetime import datetime, timezone from typing import Any import stripe as stripe_lib +from sqlalchemy import extract, func, select +from sqlalchemy.ext.asyncio import AsyncSession from app.config.settings import settings from app.marketplace.plugin_registry import registry +from app.models import Plugin, RevenueEvent logger = logging.getLogger(__name__) @@ -35,10 +38,6 @@ class RevenueShare: is not configured, consistent with the rest of the billing layer. """ - def __init__(self) -> None: - # Step 12 replaces with revenue_events DB table - self._events: list[dict[str, Any]] = [] - # ── Helpers ────────────────────────────────────────────────────── @staticmethod @@ -54,6 +53,7 @@ class RevenueShare: async def record_install( self, + db: AsyncSession, plugin_id: str, user_id: str, amount_cents: int, @@ -72,11 +72,12 @@ class RevenueShare: stripe_transfer_id: str | None = None if amount_cents > 0 and self._stripe_configured(): - plugin_entry = registry._catalog.get(plugin_id) + # Look up the plugin's author Stripe account from the DB + result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + plugin_row = result.scalar_one_or_none() developer_stripe_account: str | None = None - if plugin_entry: - # Step 12: look up developer's Stripe account from DB - # For now, the author field is used as a placeholder key. + if plugin_row and plugin_row.author_id: + # Future: look up user.stripe_connect_account_id developer_stripe_account = None # no real account yet if developer_stripe_account: @@ -103,22 +104,21 @@ class RevenueShare: plugin_id, ) - self._events.append( - { - "plugin_id": plugin_id, - "user_id": user_id, - "amount_cents": amount_cents, - "developer_share_cents": developer_share_cents, - "stripe_transfer_id": stripe_transfer_id, - "paid_at": None, - "created_at": int(time.time()), - } + event = RevenueEvent( + plugin_id=plugin_id, + user_id=user_id, + amount_cents=amount_cents, + developer_share_cents=developer_share_cents, + stripe_transfer_id=stripe_transfer_id, ) + db.add(event) + await db.commit() - await registry.record_install(plugin_id) + await registry.record_install(db, plugin_id) async def get_earnings( self, + db: AsyncSession, developer_id: str, period: str | None = None, ) -> dict[str, Any]: @@ -136,54 +136,81 @@ class RevenueShare: "developer_share_cents": int, } """ - # Find plugin ids belonging to this developer - developer_plugin_ids: set[str] = { - pid - for pid, entry in registry._catalog.items() - if entry["manifest"].author == developer_id - } + # Find plugin ids belonging to this developer (by author_name match) + plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id) + plugin_result = await db.execute(plugin_q) + developer_plugin_ids = [row[0] for row in plugin_result.all()] - events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids] + if not developer_plugin_ids: + return { + "developer_id": developer_id, + "period": period, + "total_installs": 0, + "total_revenue_cents": 0, + "developer_share_cents": 0, + } + + query = select( + func.count().label("total_installs"), + func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"), + func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"), + ).where(RevenueEvent.plugin_id.in_(developer_plugin_ids)) if period: - # Filter by YYYY-MM prefix of the created_at timestamp - events = [ - e - for e in events - if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period - ] + # Filter by YYYY-MM: extract year and month from created_at + try: + year, month = period.split("-") + query = query.where( + extract("year", RevenueEvent.created_at) == int(year), + extract("month", RevenueEvent.created_at) == int(month), + ) + except ValueError: + pass # invalid period format — return all + + result = await db.execute(query) + row = result.one() return { "developer_id": developer_id, "period": period, - "total_installs": len(events), - "total_revenue_cents": sum(e["amount_cents"] for e in events), - "developer_share_cents": sum(e["developer_share_cents"] for e in events), + "total_installs": row.total_installs, + "total_revenue_cents": row.total_revenue, + "developer_share_cents": row.dev_share, } - async def payout_developer(self, plugin_id: str, period: str) -> None: + async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None: """Aggregate unpaid revenue for *period* and issue a Stripe Transfer. Marks processed events with ``paid_at`` timestamp. Stubs gracefully when Stripe is not configured. """ - unpaid = [ - e - for e in self._events - if e["plugin_id"] == plugin_id - and e["paid_at"] is None - and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period - ] + try: + year, month = period.split("-") + year_int, month_int = int(year), int(month) + except ValueError: + logger.warning("Invalid period format: %s", period) + return - total_dev_share = sum(e["developer_share_cents"] for e in unpaid) + result = await db.execute( + select(RevenueEvent).where( + RevenueEvent.plugin_id == plugin_id, + RevenueEvent.paid_at.is_(None), + extract("year", RevenueEvent.created_at) == year_int, + extract("month", RevenueEvent.created_at) == month_int, + ) + ) + unpaid = list(result.scalars().all()) + + total_dev_share = sum(e.developer_share_cents for e in unpaid) if total_dev_share <= 0 or not unpaid: logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period) return if self._stripe_configured(): - plugin_entry = registry._catalog.get(plugin_id) - developer_stripe_account: str | None = None # Step 12: fetch from DB - if plugin_entry and developer_stripe_account: + plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id)) + plugin_row = plugin_result.scalar_one_or_none() + developer_stripe_account: str | None = None # Future: fetch from DB + if plugin_row and developer_stripe_account: try: s = self._stripe() s.Transfer.create( @@ -196,9 +223,10 @@ class RevenueShare: logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc) return - paid_ts = int(time.time()) + paid_ts = datetime.now(timezone.utc) for event in unpaid: - event["paid_at"] = paid_ts + event.paid_at = paid_ts + await db.commit() # Module-level singleton diff --git a/app/models.py b/app/models.py index ee5ba03..f259fca 100644 --- a/app/models.py +++ b/app/models.py @@ -32,9 +32,9 @@ from sqlalchemy import ( String, Text, UniqueConstraint, + Uuid, func, ) -from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column, relationship from app.db import Base @@ -64,7 +64,7 @@ class User(Base): __tablename__ = "users" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) password_hash: Mapped[str] = mapped_column(String(255), nullable=False) @@ -89,10 +89,10 @@ class RefreshToken(Base): __tablename__ = "refresh_tokens" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) user_id: Mapped[str] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) @@ -107,10 +107,10 @@ class Subscription(Base): __tablename__ = "subscriptions" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) user_id: Mapped[str] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, unique=True, index=True ) stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True) @@ -128,10 +128,10 @@ class StorageRecord(Base): __tablename__ = "storage_records" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) user_id: Mapped[str] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) table_name: Mapped[str] = mapped_column(String(100), nullable=False) s3_key: Mapped[str] = mapped_column(String(500), nullable=False) @@ -149,10 +149,10 @@ class BackupMetadata(Base): __tablename__ = "backup_metadata" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) user_id: Mapped[str] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) s3_key: Mapped[str] = mapped_column(String(500), nullable=False) version: Mapped[int] = mapped_column(Integer, nullable=False) @@ -173,7 +173,7 @@ class Plugin(Base): version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0") # nullable until developer account system is built author_id: Mapped[str | None] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="") category: Mapped[str] = mapped_column(String(100), nullable=False, default="") @@ -207,13 +207,13 @@ class PluginInstallation(Base): __table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),) id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) plugin_id: Mapped[str] = mapped_column( String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True ) user_id: Mapped[str] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) installed_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, server_default=func.now() @@ -226,13 +226,13 @@ class PluginReview(Base): __tablename__ = "plugin_reviews" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) plugin_id: Mapped[str] = mapped_column( String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True ) reviewer_id: Mapped[str | None] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True ) decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False) notes: Mapped[str | None] = mapped_column(Text, nullable=True) @@ -250,13 +250,13 @@ class RevenueEvent(Base): __tablename__ = "revenue_events" id: Mapped[str] = mapped_column( - UUID(as_uuid=False), primary_key=True, default=_uuid + Uuid(as_uuid=False), primary_key=True, default=_uuid ) plugin_id: Mapped[str] = mapped_column( String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True ) user_id: Mapped[str] = mapped_column( - UUID(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True ) amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0) developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0) diff --git a/requirements.txt b/requirements.txt index f2465ff..b0d98ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,8 +15,10 @@ bcrypt>=4.2.0 python-dotenv>=1.0.0 httpx>=0.28.0 websockets>=14.0 +psycopg2-binary>=2.9.0 pytest>=8.0.0 pytest-asyncio>=0.24.0 +aiosqlite>=0.20.0 moto[s3]>=5.0.0 pinecone>=5.0.0 qdrant-client>=1.7.0 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a4837d7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,208 @@ +"""Shared test fixtures for database-backed tests. + +Provides an async SQLite in-memory engine that auto-creates all tables, +a per-test session, and a FastAPI ``TestClient`` wired to use it. +""" + +from __future__ import annotations + +import json +import time +import uuid +from collections.abc import AsyncGenerator, Generator + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient +from jose import jwt +from sqlalchemy import StaticPool, event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.config.settings import settings +from app.db import Base, get_session +from app.main import app +from app.models import Plugin, Subscription, User + +# ── Fixed test user IDs (one per tier) ─────────────────────────────── + +TEST_USER_IDS: dict[str, str] = { + "free": "00000000-0000-0000-0000-000000000001", + "pro": "00000000-0000-0000-0000-000000000002", + "power": "00000000-0000-0000-0000-000000000003", + "team": "00000000-0000-0000-0000-000000000004", +} + +# ── Async SQLite engine ────────────────────────────────────────────── + +_TEST_ENGINE = create_async_engine( + "sqlite+aiosqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) + +_TestSessionLocal = async_sessionmaker( + _TEST_ENGINE, + expire_on_commit=False, +) + + +# Enable foreign key enforcement for SQLite (off by default). +@event.listens_for(_TEST_ENGINE.sync_engine, "connect") +def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001 + cursor = dbapi_conn.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + +# ── Fixtures ───────────────────────────────────────────────────────── + +@pytest_asyncio.fixture(autouse=True) +async def _create_tables(): + """Create all tables before each test, seed test users, then drop after.""" + async with _TEST_ENGINE.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Seed one User + Subscription per tier so FK constraints and auth work. + async with _TestSessionLocal() as session: + for tier, uid in TEST_USER_IDS.items(): + session.add(User( + id=uid, + email=f"{tier}@test.com", + password_hash="$2b$12$fakehashfortesting000000000000000000000000000", + tier=tier, + )) + session.add(Subscription( + id=str(uuid.uuid4()), + user_id=uid, + tier=tier, + stripe_subscription_id=f"sub_test_{tier}", + status="active", + )) + await session.commit() + + yield + async with _TEST_ENGINE.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + + +@pytest_asyncio.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Yield a per-test async DB session.""" + async with _TestSessionLocal() as session: + yield session + + +@pytest.fixture +def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001 + """FastAPI test client with ``get_session`` overridden to use the test DB.""" + + async def _override_get_session() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + app.dependency_overrides[get_session] = _override_get_session + with TestClient(app) as c: + yield c + app.dependency_overrides.pop(get_session, None) + + +# ── Seed data helpers ──────────────────────────────────────────────── + +_SEED_PLUGINS = [ + Plugin( + id="plugin-github-sync", + name="GitHub Sync", + description="Sync tasks with GitHub Issues and pull requests.", + version="1.0.0", + author_name="Adiuva", + category="productivity", + price_cents=0, + permissions=json.dumps(["read:tasks", "write:tasks"]), + status="approved", + s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip", + install_count=0, + avg_rating=0.0, + ), + Plugin( + id="plugin-slack-notify", + name="Slack Notifier", + description="Post task and checkpoint updates to Slack channels.", + version="1.2.0", + author_name="Adiuva", + category="communication", + price_cents=499, + permissions=json.dumps(["read:tasks", "read:checkpoints"]), + status="approved", + s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip", + install_count=0, + avg_rating=0.0, + ), + Plugin( + id="plugin-time-tracker", + name="Time Tracker", + description="Track time spent on tasks with automatic reporting.", + version="0.9.1", + author_name="Third Party", + category="productivity", + price_cents=999, + permissions=json.dumps(["read:tasks", "write:tasks"]), + status="approved", + s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip", + install_count=0, + avg_rating=0.0, + ), +] + + +@pytest_asyncio.fixture +async def seed_plugins(db_session: AsyncSession) -> list[Plugin]: + """Insert the 3 default approved plugins and return them.""" + plugins = [] + for template in _SEED_PLUGINS: + p = Plugin( + id=template.id, + name=template.name, + description=template.description, + version=template.version, + author_name=template.author_name, + category=template.category, + price_cents=template.price_cents, + permissions=template.permissions, + status=template.status, + s3_package_key=template.s3_package_key, + install_count=template.install_count, + avg_rating=template.avg_rating, + ) + db_session.add(p) + plugins.append(p) + await db_session.commit() + return plugins + + +# ── JWT helpers ────────────────────────────────────────────────────── + + +def make_jwt( + tier: str = "power", + user_id: str | None = None, + email: str | None = None, +) -> str: + """Create a signed test JWT. + + Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can + find the corresponding ``Subscription`` row in the test database. + """ + uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4())) + now = int(time.time()) + payload = { + "sub": uid, + "email": email or f"{tier}@test.com", + "tier": tier, + "exp": now + 3600, + "iat": now, + } + return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) + + +def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]: + """Return an Authorization header dict for the given tier.""" + return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 343a171..8721bbc 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -18,13 +18,30 @@ from fastapi.testclient import TestClient from jose import jwt from app.config.settings import settings +from app.db import get_session from app.main import app from app.schemas import ChatResponse +from tests.conftest import TEST_USER_IDS # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- +# --------------------------------------------------------------------------- +# Autouse: redirect all DB access to the in-memory SQLite test engine. +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _override_db(db_session): + """Route all get_session calls to the test SQLite session.""" + async def _gen(): + yield db_session + + app.dependency_overrides[get_session] = _gen + yield + app.dependency_overrides.pop(get_session, None) + + _CHAT_BODY = { "message": "hello", "context": { @@ -74,14 +91,15 @@ class TestAuthMiddleware: """Tests exercised via GET /api/v1/auth/me.""" def test_valid_token_returns_profile(self) -> None: - uid = str(uuid.uuid4()) - token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro") + # Use the seeded pro user so the subscription lookup returns 'pro'. + uid = TEST_USER_IDS["pro"] + token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro") with TestClient(app) as client: resp = client.get("/api/v1/auth/me", headers=_auth_header(token)) assert resp.status_code == 200 data = resp.json() assert data["id"] == uid - assert data["email"] == "alice@example.com" + assert data["email"] == "pro@test.com" assert data["tier"] == "pro" def test_missing_token_returns_401(self) -> None: diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 81261e4..6a293ff 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,52 +1,34 @@ -"""Tests for Step 10: Plugin Marketplace. +"""Tests for Step 10+12: Plugin Marketplace (DB-backed). Covers: - - PluginRegistry: catalog management, filtering, sorting, install counts + - PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL) - ReviewQueue: pending queue, review decisions, manifest security checklist - - RevenueShare: install event recording, earnings aggregation + - RevenueShare: install event recording, earnings aggregation (PostgreSQL) - Route integration: tier gate, list/get/install/uninstall via TestClient """ from __future__ import annotations -import time +import json import uuid import pytest import pytest_asyncio -from fastapi.testclient import TestClient -from jose import jwt -from unittest.mock import patch +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession -from app.config.settings import settings -from app.main import app from app.marketplace.plugin_registry import PluginRegistry from app.marketplace.plugin_review import ReviewQueue, validate_manifest from app.marketplace.revenue_share import RevenueShare +from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent from app.schemas import PluginManifest +from tests.conftest import TEST_USER_IDS, auth_header # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _make_jwt(tier: str = "power", user_id: str | None = None) -> str: - uid = user_id or str(uuid.uuid4()) - now = int(time.time()) - payload = { - "sub": uid, - "email": f"{uid[:8]}@example.com", - "tier": tier, - "exp": now + 3600, - "iat": now, - } - return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) - - -def _auth(tier: str = "power") -> dict[str, str]: - return {"Authorization": f"Bearer {_make_jwt(tier)}"} - - def _fresh_manifest( plugin_id: str | None = None, category: str = "productivity", @@ -67,118 +49,150 @@ def _fresh_manifest( # --------------------------------------------------------------------------- -# PluginRegistry +# PluginRegistry (DB-backed) # --------------------------------------------------------------------------- class TestPluginRegistry: - """Each test uses a fresh PluginRegistry instance to avoid catalog pollution.""" + """Each test uses the conftest db_session fixture with a fresh in-memory DB.""" @pytest.fixture def reg(self) -> PluginRegistry: return PluginRegistry() @pytest.mark.asyncio - async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None: - result = await reg.list_plugins() + async def test_seed_plugins_are_listed( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + result = await reg.list_plugins(db_session) assert result.total == 3 assert all(p.id.startswith("plugin-") for p in result.plugins) @pytest.mark.asyncio - async def test_list_approved_only(self, reg: PluginRegistry) -> None: + async def test_list_approved_only( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: manifest = _fresh_manifest() - await reg.submit_plugin(manifest, "plugins/key.zip") - result = await reg.list_plugins() + await reg.submit_plugin(db_session, manifest, "plugins/key.zip") + result = await reg.list_plugins(db_session) ids = [p.id for p in result.plugins] assert manifest.id not in ids # still pending @pytest.mark.asyncio - async def test_list_filter_by_category(self, reg: PluginRegistry) -> None: - result = await reg.list_plugins(category="communication") + async def test_list_filter_by_category( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + result = await reg.list_plugins(db_session, category="communication") assert result.total == 1 assert result.plugins[0].id == "plugin-slack-notify" @pytest.mark.asyncio - async def test_list_filter_by_query(self, reg: PluginRegistry) -> None: - result = await reg.list_plugins(query="time") + async def test_list_filter_by_query( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + result = await reg.list_plugins(db_session, query="time") assert result.total == 1 assert result.plugins[0].id == "plugin-time-tracker" @pytest.mark.asyncio - async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None: - await reg.record_install("plugin-slack-notify") - await reg.record_install("plugin-slack-notify") - result = await reg.list_plugins(sort="installs") + async def test_list_sort_by_installs( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + await reg.record_install(db_session, "plugin-slack-notify") + await reg.record_install(db_session, "plugin-slack-notify") + result = await reg.list_plugins(db_session, sort="installs") assert result.plugins[0].id == "plugin-slack-notify" @pytest.mark.asyncio - async def test_get_plugin_found(self, reg: PluginRegistry) -> None: - entry = await reg.get_plugin("plugin-github-sync") + async def test_get_plugin_found( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + entry = await reg.get_plugin(db_session, "plugin-github-sync") assert entry is not None assert entry["manifest"].id == "plugin-github-sync" assert "install_count" in entry @pytest.mark.asyncio - async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None: - entry = await reg.get_plugin("no-such-plugin") + async def test_get_plugin_not_found( + self, reg: PluginRegistry, db_session: AsyncSession + ) -> None: + entry = await reg.get_plugin(db_session, "no-such-plugin") assert entry is None @pytest.mark.asyncio - async def test_submit_sets_pending(self, reg: PluginRegistry) -> None: + async def test_submit_sets_pending( + self, reg: PluginRegistry, db_session: AsyncSession + ) -> None: manifest = _fresh_manifest() - plugin_id = await reg.submit_plugin(manifest, "key.zip") + plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip") assert plugin_id == manifest.id - assert reg._catalog[plugin_id]["status"] == "pending_review" + result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id)) + row = result.scalar_one() + assert row.status == "pending_review" @pytest.mark.asyncio - async def test_approve_makes_visible(self, reg: PluginRegistry) -> None: + async def test_approve_makes_visible( + self, reg: PluginRegistry, db_session: AsyncSession + ) -> None: manifest = _fresh_manifest() - await reg.submit_plugin(manifest, "key.zip") - await reg.approve_plugin(manifest.id) - result = await reg.list_plugins() + await reg.submit_plugin(db_session, manifest, "key.zip") + await reg.approve_plugin(db_session, manifest.id) + result = await reg.list_plugins(db_session) assert manifest.id in [p.id for p in result.plugins] @pytest.mark.asyncio - async def test_reject_stores_reason(self, reg: PluginRegistry) -> None: + async def test_reject_stores_reason( + self, reg: PluginRegistry, db_session: AsyncSession + ) -> None: manifest = _fresh_manifest() - await reg.submit_plugin(manifest, "key.zip") - await reg.reject_plugin(manifest.id, reason="Unsafe permissions") - assert reg._catalog[manifest.id]["status"] == "rejected" - assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions" - result = await reg.list_plugins() - assert manifest.id not in [p.id for p in result.plugins] + await reg.submit_plugin(db_session, manifest, "key.zip") + await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions") + result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id)) + row = result.scalar_one() + assert row.status == "rejected" + assert row.rejection_reason == "Unsafe permissions" + listed = await reg.list_plugins(db_session) + assert manifest.id not in [p.id for p in listed.plugins] @pytest.mark.asyncio - async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None: + async def test_approve_unknown_raises_key_error( + self, reg: PluginRegistry, db_session: AsyncSession + ) -> None: with pytest.raises(KeyError): - await reg.approve_plugin("ghost-plugin") + await reg.approve_plugin(db_session, "ghost-plugin") @pytest.mark.asyncio - async def test_record_install_increments_count(self, reg: PluginRegistry) -> None: - await reg.record_install("plugin-github-sync") - entry = await reg.get_plugin("plugin-github-sync") + async def test_record_install_increments_count( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + await reg.record_install(db_session, "plugin-github-sync") + entry = await reg.get_plugin(db_session, "plugin-github-sync") assert entry is not None assert entry["install_count"] == 1 @pytest.mark.asyncio - async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None: - await reg.record_install("plugin-github-sync") - await reg.record_install("plugin-github-sync") - await reg.record_uninstall("plugin-github-sync") - entry = await reg.get_plugin("plugin-github-sync") + async def test_record_uninstall_decrements_count( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + await reg.record_install(db_session, "plugin-github-sync") + await reg.record_install(db_session, "plugin-github-sync") + await reg.record_uninstall(db_session, "plugin-github-sync") + entry = await reg.get_plugin(db_session, "plugin-github-sync") assert entry is not None assert entry["install_count"] == 1 @pytest.mark.asyncio - async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None: - await reg.record_uninstall("plugin-github-sync") # already 0 - entry = await reg.get_plugin("plugin-github-sync") + async def test_record_uninstall_floors_at_zero( + self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin] + ) -> None: + await reg.record_uninstall(db_session, "plugin-github-sync") + entry = await reg.get_plugin(db_session, "plugin-github-sync") assert entry is not None assert entry["install_count"] == 0 # --------------------------------------------------------------------------- -# ReviewQueue +# ReviewQueue (DB-backed) # --------------------------------------------------------------------------- @@ -188,37 +202,47 @@ class TestReviewQueue: return PluginRegistry() @pytest.fixture - def queue(self, reg: PluginRegistry) -> ReviewQueue: - # Patch the 'registry' name as bound inside plugin_review.py - with patch("app.marketplace.plugin_review.registry", reg): - yield ReviewQueue() + def queue(self) -> ReviewQueue: + return ReviewQueue() @pytest.mark.asyncio async def test_get_pending_returns_submitted_plugins( - self, reg: PluginRegistry, queue: ReviewQueue + self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession ) -> None: manifest = _fresh_manifest() - await reg.submit_plugin(manifest, "key.zip") - pending = await queue.get_pending() + await reg.submit_plugin(db_session, manifest, "key.zip") + pending = await queue.get_pending(db_session) assert any(p["plugin_id"] == manifest.id for p in pending) @pytest.mark.asyncio async def test_submit_review_approved( - self, reg: PluginRegistry, queue: ReviewQueue + self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession ) -> None: manifest = _fresh_manifest() - await reg.submit_plugin(manifest, "key.zip") - await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good") - assert reg._catalog[manifest.id]["status"] == "approved" + await reg.submit_plugin(db_session, manifest, "key.zip") + await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good") + result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id)) + row = result.scalar_one() + assert row.status == "approved" + # Check review row was persisted + review_result = await db_session.execute( + select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id) + ) + review = review_result.scalar_one() + assert review.decision == "approved" @pytest.mark.asyncio async def test_submit_review_rejected( - self, reg: PluginRegistry, queue: ReviewQueue + self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession ) -> None: manifest = _fresh_manifest() - await reg.submit_plugin(manifest, "key.zip") - await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions") - assert reg._catalog[manifest.id]["status"] == "rejected" + await reg.submit_plugin(db_session, manifest, "key.zip") + await queue.submit_review( + db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions" + ) + result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id)) + row = result.scalar_one() + assert row.status == "rejected" def test_validate_manifest_ok(self) -> None: manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"]) @@ -241,65 +265,66 @@ class TestReviewQueue: # --------------------------------------------------------------------------- -# RevenueShare +# RevenueShare (DB-backed) # --------------------------------------------------------------------------- class TestRevenueShare: @pytest.fixture - def reg(self) -> PluginRegistry: - return PluginRegistry() - - @pytest.fixture - def rs(self, reg: PluginRegistry) -> RevenueShare: - # Patch the 'registry' name as bound inside revenue_share.py - with patch("app.marketplace.revenue_share.registry", reg): - yield RevenueShare() + def rs(self) -> RevenueShare: + return RevenueShare() @pytest.mark.asyncio async def test_record_install_free_plugin( - self, reg: PluginRegistry, rs: RevenueShare + self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin] ) -> None: - await rs.record_install("plugin-github-sync", "user-1", amount_cents=0) - assert len(rs._events) == 1 - assert rs._events[0]["developer_share_cents"] == 0 + await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0) + result = await db_session.execute( + select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync") + ) + event = result.scalar_one() + assert event.developer_share_cents == 0 @pytest.mark.asyncio async def test_record_install_paid_plugin_no_stripe( - self, reg: PluginRegistry, rs: RevenueShare + self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin] ) -> None: - # No STRIPE_SECRET_KEY configured in test env — should not crash - await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499) - assert len(rs._events) == 1 - assert rs._events[0]["amount_cents"] == 499 - assert rs._events[0]["developer_share_cents"] == int(499 * 0.70) + await rs.record_install( + db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499 + ) + result = await db_session.execute( + select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify") + ) + event = result.scalar_one() + assert event.amount_cents == 499 + assert event.developer_share_cents == int(499 * 0.70) @pytest.mark.asyncio async def test_record_install_increments_registry_count( - self, reg: PluginRegistry, rs: RevenueShare + self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin] ) -> None: - await rs.record_install("plugin-github-sync", "user-1", amount_cents=0) - entry = await reg.get_plugin("plugin-github-sync") + reg = PluginRegistry() + await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0) + entry = await reg.get_plugin(db_session, "plugin-github-sync") assert entry is not None assert entry["install_count"] == 1 @pytest.mark.asyncio async def test_get_earnings_empty( - self, reg: PluginRegistry, rs: RevenueShare + self, rs: RevenueShare, db_session: AsyncSession ) -> None: - result = await rs.get_earnings("unknown-dev") + result = await rs.get_earnings(db_session, "unknown-dev") assert result["total_installs"] == 0 assert result["total_revenue_cents"] == 0 assert result["developer_share_cents"] == 0 @pytest.mark.asyncio async def test_get_earnings_aggregates( - self, reg: PluginRegistry, rs: RevenueShare + self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin] ) -> None: - # "Adiuva" is the author of the seeded plugins - await rs.record_install("plugin-slack-notify", "u1", amount_cents=499) - await rs.record_install("plugin-slack-notify", "u2", amount_cents=499) - result = await rs.get_earnings("Adiuva") + await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499) + await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499) + result = await rs.get_earnings(db_session, "Adiuva") assert result["total_installs"] == 2 assert result["total_revenue_cents"] == 998 assert result["developer_share_cents"] == int(499 * 0.70) * 2 @@ -311,77 +336,67 @@ class TestRevenueShare: class TestPluginRoutes: - def test_list_plugins_requires_power_tier(self) -> None: - with TestClient(app) as client: - resp = client.get("/api/v1/plugins", headers=_auth("free")) + def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None: + resp = client.get("/api/v1/plugins", headers=auth_header("free")) assert resp.status_code == 403 - def test_list_plugins_pro_tier_blocked(self) -> None: - with TestClient(app) as client: - resp = client.get("/api/v1/plugins", headers=_auth("pro")) + def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None: + resp = client.get("/api/v1/plugins", headers=auth_header("pro")) assert resp.status_code == 403 - def test_list_plugins_power_tier_ok(self) -> None: - with TestClient(app) as client: - resp = client.get("/api/v1/plugins", headers=_auth("power")) + def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None: + resp = client.get("/api/v1/plugins", headers=auth_header("power")) assert resp.status_code == 200 data = resp.json() assert "plugins" in data - assert data["total"] >= 3 + assert data["total"] == 3 - def test_list_plugins_team_tier_ok(self) -> None: - with TestClient(app) as client: - resp = client.get("/api/v1/plugins", headers=_auth("team")) + def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None: + resp = client.get("/api/v1/plugins", headers=auth_header("team")) assert resp.status_code == 200 - def test_get_plugin_found(self) -> None: - with TestClient(app) as client: - resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth()) + def test_get_plugin_found(self, client, seed_plugins) -> None: + resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header()) assert resp.status_code == 200 data = resp.json() assert data["plugin"]["id"] == "plugin-github-sync" assert "install_count" in data - def test_get_plugin_not_found(self) -> None: - with TestClient(app) as client: - resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth()) + def test_get_plugin_not_found(self, client, seed_plugins) -> None: + resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header()) assert resp.status_code == 404 - def test_install_plugin_free(self) -> None: - with TestClient(app) as client: - resp = client.post( - "/api/v1/plugins/plugin-github-sync/install", - json={"plugin_id": "plugin-github-sync"}, - headers=_auth(), - ) + def test_install_plugin_free(self, client, seed_plugins) -> None: + resp = client.post( + "/api/v1/plugins/plugin-github-sync/install", + json={"plugin_id": "plugin-github-sync"}, + headers=auth_header(), + ) assert resp.status_code == 200 data = resp.json() assert data["ok"] is True assert "download_url" in data - def test_install_plugin_not_found(self) -> None: - with TestClient(app) as client: - resp = client.post( - "/api/v1/plugins/ghost/install", - json={"plugin_id": "ghost"}, - headers=_auth(), - ) + def test_install_plugin_not_found(self, client, seed_plugins) -> None: + resp = client.post( + "/api/v1/plugins/ghost/install", + json={"plugin_id": "ghost"}, + headers=auth_header(), + ) assert resp.status_code == 404 - def test_uninstall_plugin_ok(self) -> None: - with TestClient(app) as client: - resp = client.delete( - "/api/v1/plugins/plugin-github-sync/install", - headers=_auth(), - ) + def test_uninstall_plugin_ok(self, client, seed_plugins) -> None: + resp = client.delete( + "/api/v1/plugins/plugin-github-sync/install", + headers=auth_header(), + ) assert resp.status_code == 200 assert resp.json()["ok"] is True - def test_install_requires_power_tier(self) -> None: - with TestClient(app) as client: - resp = client.post( - "/api/v1/plugins/plugin-github-sync/install", - json={"plugin_id": "plugin-github-sync"}, - headers=_auth("free"), - ) + def test_install_requires_power_tier(self, client, seed_plugins) -> None: + resp = client.post( + "/api/v1/plugins/plugin-github-sync/install", + json={"plugin_id": "plugin-github-sync"}, + headers=auth_header("free"), + ) assert resp.status_code == 403