"""Shared test fixtures for database-backed tests. Provides an async SQLite in-memory engine that auto-creates all tables, a per-test session, and a FastAPI ``TestClient`` wired to use it. """ from __future__ import annotations import time import uuid from collections.abc import AsyncGenerator, Generator import pytest import pytest_asyncio from fastapi.testclient import TestClient from jose import jwt from sqlalchemy import StaticPool, event from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy import select from app.config.settings import settings from app.db import Base, get_session from app.main import app from app.models import Subscription, User # ── Fixed test user IDs (one per tier) ─────────────────────────────── TEST_USER_IDS: dict[str, str] = { "free": "00000000-0000-0000-0000-000000000001", "pro": "00000000-0000-0000-0000-000000000002", "power": "00000000-0000-0000-0000-000000000003", "team": "00000000-0000-0000-0000-000000000004", } # ── Async SQLite engine ────────────────────────────────────────────── _TEST_ENGINE = create_async_engine( "sqlite+aiosqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) _TestSessionLocal = async_sessionmaker( _TEST_ENGINE, expire_on_commit=False, ) # Enable foreign key enforcement for SQLite (off by default). @event.listens_for(_TEST_ENGINE.sync_engine, "connect") def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001 cursor = dbapi_conn.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() # ── Fixtures ───────────────────────────────────────────────────────── @pytest_asyncio.fixture(autouse=True) async def _create_tables(): """Create all tables before each test, seed test users, then drop after.""" async with _TEST_ENGINE.begin() as conn: await conn.run_sync(Base.metadata.create_all) # Seed one User + Subscription per tier so FK constraints and auth work. async with _TestSessionLocal() as session: for tier, uid in TEST_USER_IDS.items(): session.add(User( id=uid, email=f"{tier}@test.com", password_hash="$2b$12$fakehashfortesting000000000000000000000000000", tier=tier, )) session.add(Subscription( id=str(uuid.uuid4()), user_id=uid, tier=tier, stripe_subscription_id=f"sub_test_{tier}", status="active", )) await session.commit() yield async with _TEST_ENGINE.begin() as conn: await conn.run_sync(Base.metadata.drop_all) @pytest_asyncio.fixture async def db_session() -> AsyncGenerator[AsyncSession, None]: """Yield a per-test async DB session.""" async with _TestSessionLocal() as session: yield session @pytest.fixture def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001 """FastAPI test client with ``get_session`` overridden to use the test DB.""" async def _override_get_session() -> AsyncGenerator[AsyncSession, None]: yield db_session app.dependency_overrides[get_session] = _override_get_session with TestClient(app) as c: yield c app.dependency_overrides.pop(get_session, None) # ── JWT helpers ────────────────────────────────────────────────────── def make_jwt( tier: str = "power", user_id: str | None = None, email: str | None = None, ) -> str: """Create a signed test JWT. Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can find the corresponding ``Subscription`` row in the test database. """ uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4())) now = int(time.time()) payload = { "sub": uid, "email": email or f"{tier}@test.com", "tier": tier, "exp": now + 3600, "iat": now, } return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]: """Return an Authorization header dict for the given tier.""" return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} # ── Convenience aliases and per-tier user fixtures ──────────────────── @pytest_asyncio.fixture async def db(db_session: AsyncSession) -> AsyncSession: """Alias for db_session — used by folder quota tests.""" return db_session @pytest_asyncio.fixture async def test_user_free(db_session: AsyncSession): """Return the seeded free-tier User row.""" result = await db_session.execute( select(User).where(User.id == TEST_USER_IDS["free"]) ) return result.scalar_one() @pytest_asyncio.fixture async def test_user_power(db_session: AsyncSession): """Return the seeded power-tier User row.""" result = await db_session.execute( select(User).where(User.id == TEST_USER_IDS["power"]) ) return result.scalar_one() @pytest.fixture def auth_headers_free() -> dict[str, str]: """Authorization header for the seeded free-tier user.""" return auth_header("free") # ── CLI options ─────────────────────────────────────────────────────── def pytest_addoption(parser): parser.addoption( "--preprocess-dir", default=None, help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)", ) parser.addoption( "--runner-dir", default=None, help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)", ) parser.addoption( "--journey-dir", default=None, help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)", )