"""Shared database-backed test fixtures.

Supplies an in-memory async SQLite engine whose tables are created (and
seeded) fresh for every test, a per-test ``AsyncSession``, and a FastAPI
``TestClient`` whose ``get_session`` dependency is routed to that session.
"""

from __future__ import annotations

import json
import os
import time
import uuid
from collections.abc import AsyncGenerator, Generator
from unittest.mock import patch

import boto3
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from jose import jwt
from moto import mock_aws
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
from app.models import Plugin, Subscription, User

# ── Fixed test user IDs (one per tier) ───────────────────────────────
# Deterministic UUIDs whose last group is the tier's 1-based position:
#   free  -> 00000000-0000-0000-0000-000000000001
#   pro   -> 00000000-0000-0000-0000-000000000002
#   power -> 00000000-0000-0000-0000-000000000003
#   team  -> 00000000-0000-0000-0000-000000000004
TEST_USER_IDS: dict[str, str] = {
    tier: f"00000000-0000-0000-0000-{position:012d}"
    for position, tier in enumerate(("free", "pro", "power", "team"), start=1)
}

# ── Async SQLite engine ──────────────────────────────────────────────
# A bare "sqlite+aiosqlite://" URL is an in-memory database. StaticPool
# hands out one single connection for every checkout, so all sessions see
# the same in-memory data instead of each getting an empty database.
_TEST_ENGINE = create_async_engine(
    "sqlite+aiosqlite://",
    poolclass=StaticPool,
    # Allow the shared connection to be used from threads other than the
    # one that opened it.
    connect_args={"check_same_thread": False},
)

_TestSessionLocal = async_sessionmaker(
    bind=_TEST_ENGINE,
    # Keep loaded attributes readable after commit() instead of expiring them.
    expire_on_commit=False,
)


# Enable foreign key enforcement for SQLite (off by default).
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_conn, _connection_record):  # noqa: ANN001
    """Run ``PRAGMA foreign_keys=ON`` on every new DBAPI connection."""
    cursor = dbapi_conn.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()


# ── Fixtures ─────────────────────────────────────────────────────────


@pytest_asyncio.fixture(autouse=True)
async def _create_tables():
    """Create all tables before each test, seed test users, then drop after.

    One ``User`` plus an active ``Subscription`` is inserted for every tier in
    ``TEST_USER_IDS`` so that foreign-key constraints and auth resolve to
    real rows.
    """
    async with _TEST_ENGINE.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # Seed one User + Subscription per tier so FK constraints and auth work.
    async with _TestSessionLocal() as session:
        for tier, uid in TEST_USER_IDS.items():
            session.add(User(
                id=uid,
                email=f"{tier}@test.com",
                password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
                tier=tier,
            ))
            session.add(Subscription(
                id=str(uuid.uuid4()),
                user_id=uid,
                tier=tier,
                stripe_subscription_id=f"sub_test_{tier}",
                status="active",
            ))
        await session.commit()

    yield

    async with _TEST_ENGINE.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a per-test async DB session, closed when the test finishes."""
    async with _TestSessionLocal() as session:
        yield session


@pytest.fixture
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]:
    """FastAPI test client with ``get_session`` overridden to use the test DB.

    Every request handled by the app receives the same per-test
    ``db_session``. The override is removed in ``finally`` so it cannot leak
    into later tests even if client startup or shutdown raises.
    """

    async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    app.dependency_overrides[get_session] = _override_get_session
    try:
        with TestClient(app) as c:
            yield c
    finally:
        app.dependency_overrides.pop(get_session, None)


# ── Seed data helpers ────────────────────────────────────────────────

# Column values for the default approved plugins. Kept as plain data so every
# caller builds brand-new ``Plugin`` objects from it with ``Plugin(**fields)``;
# adding a column here automatically flows into ``seed_plugins``.
_SEED_PLUGIN_DATA: list[dict[str, object]] = [
    {
        "id": "plugin-github-sync",
        "name": "GitHub Sync",
        "description": "Sync tasks with GitHub Issues and pull requests.",
        "version": "1.0.0",
        "author_name": "Adiuva",
        "category": "productivity",
        "price_cents": 0,
        "permissions": json.dumps(["read:tasks", "write:tasks"]),
        "status": "approved",
        "s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
        "install_count": 0,
        "avg_rating": 0.0,
    },
    {
        "id": "plugin-slack-notify",
        "name": "Slack Notifier",
        "description": "Post task and checkpoint updates to Slack channels.",
        "version": "1.2.0",
        "author_name": "Adiuva",
        "category": "communication",
        "price_cents": 499,
        "permissions": json.dumps(["read:tasks", "read:checkpoints"]),
        "status": "approved",
        "s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
        "install_count": 0,
        "avg_rating": 0.0,
    },
    {
        "id": "plugin-time-tracker",
        "name": "Time Tracker",
        "description": "Track time spent on tasks with automatic reporting.",
        "version": "0.9.1",
        "author_name": "Third Party",
        "category": "productivity",
        "price_cents": 999,
        "permissions": json.dumps(["read:tasks", "write:tasks"]),
        "status": "approved",
        "s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
        "install_count": 0,
        "avg_rating": 0.0,
    },
]

# Backward-compatible template objects (never added to any session) for code
# that still reads ``_SEED_PLUGINS``.
_SEED_PLUGINS = [Plugin(**fields) for fields in _SEED_PLUGIN_DATA]


@pytest_asyncio.fixture
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
    """Insert the 3 default approved plugins and return them.

    Fresh ``Plugin`` instances are built from ``_SEED_PLUGIN_DATA`` on every
    call, so no ORM object is shared between tests.
    """
    plugins = [Plugin(**fields) for fields in _SEED_PLUGIN_DATA]
    db_session.add_all(plugins)
    await db_session.commit()
    return plugins


# ── JWT helpers ──────────────────────────────────────────────────────


def make_jwt(
    tier: str = "power",
    user_id: str | None = None,
    email: str | None = None,
) -> str:
    """Create a signed test JWT.

    Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can find
    the corresponding ``Subscription`` row in the test database.

    Args:
        tier: Value of the ``tier`` claim; also selects the default subject.
        user_id: Explicit ``sub`` claim. Defaults to ``TEST_USER_IDS[tier]``,
            or a random UUID4 when *tier* has no seeded user.
        email: Explicit ``email`` claim. Defaults to ``"<tier>@test.com"``.

    Returns:
        The encoded token, signed with ``settings.JWT_SECRET`` using
        ``settings.JWT_ALGORITHM``, expiring one hour after issue.
    """
    # Only mint a random UUID when neither an explicit id nor a seeded id exists.
    uid = user_id or TEST_USER_IDS.get(tier) or str(uuid.uuid4())
    now = int(time.time())
    payload = {
        "sub": uid,
        "email": email or f"{tier}@test.com",
        "tier": tier,
        "exp": now + 3600,  # one hour
        "iat": now,
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)


def auth_header(
    tier: str = "power",
    user_id: str | None = None,
    email: str | None = None,
) -> dict[str, str]:
    """Return an Authorization header dict for the given tier.

    All arguments are forwarded unchanged to ``make_jwt``.
    """
    return {"Authorization": f"Bearer {make_jwt(tier, user_id, email)}"}


# ── S3 mock fixture ──────────────────────────────────────────────────

S3_TEST_BUCKET = "test-bucket"
S3_TEST_REGION = "us-east-1"


@pytest.fixture
def s3_bucket() -> Generator[str, None, None]:
    """Create a mocked S3 bucket via moto and patch BlobStore settings.

    Yields the bucket name. Fake AWS credentials and region are applied with
    ``patch.dict`` so they override whatever is in the real environment for
    the duration of the test and are restored afterwards, rather than being
    written into ``os.environ`` permanently.

    NOTE(review): ``app.storage.blob_store.settings`` is replaced by a
    ``MagicMock``; any setting other than the four assigned below reads as a
    mock attribute, not its real value.
    """
    fake_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_DEFAULT_REGION": S3_TEST_REGION,
    }
    with mock_aws(), patch.dict(os.environ, fake_env):
        # Named ``s3`` (not ``client``) to avoid shadowing the ``client`` fixture.
        s3 = boto3.client("s3", region_name=S3_TEST_REGION)
        s3.create_bucket(Bucket=S3_TEST_BUCKET)

        with patch("app.storage.blob_store.settings") as mock_settings:
            mock_settings.S3_BUCKET = S3_TEST_BUCKET
            mock_settings.S3_REGION = S3_TEST_REGION
            mock_settings.AWS_ACCESS_KEY_ID = "testing"
            mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
            yield S3_TEST_BUCKET