237 lines
8.1 KiB
Python
237 lines
8.1 KiB
Python
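"""Shared test fixtures for database-backed tests.

Provides an async SQLite in-memory engine that auto-creates all tables (and
seeds one user with an active subscription per tier) before each test, a
per-test session, and a FastAPI ``TestClient`` wired to use it. Also provides
plugin seed data, JWT / ``Authorization`` header helpers, and a moto-backed S3
bucket fixture.
"""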
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import time
|
|
import uuid
|
|
from collections.abc import AsyncGenerator, Generator
|
|
from unittest.mock import patch
|
|
|
|
import boto3
|
|
import pytest
|
|
import pytest_asyncio
|
|
from fastapi.testclient import TestClient
|
|
from jose import jwt
|
|
from moto import mock_aws
|
|
from sqlalchemy import StaticPool, event
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.config.settings import settings
|
|
from app.db import Base, get_session
|
|
from app.main import app
|
|
from app.models import Plugin, Subscription, User
|
|
|
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
|
|
|
# Maps each subscription tier to the fixed UUID of its seeded test user.
TEST_USER_IDS: dict[str, str] = dict(
    free="00000000-0000-0000-0000-000000000001",
    pro="00000000-0000-0000-0000-000000000002",
    power="00000000-0000-0000-0000-000000000003",
    team="00000000-0000-0000-0000-000000000004",
)
|
|
|
|
# ── Async SQLite engine ──────────────────────────────────────────────
|
|
|
|
# StaticPool keeps exactly one DBAPI connection, so every session shares the
# same in-memory database (an in-memory SQLite DB lives only as long as its
# connection). check_same_thread=False lets that sqlite3 connection be used
# outside the thread that created it.
_TEST_ENGINE = create_async_engine(
    url="sqlite+aiosqlite://",
    poolclass=StaticPool,
    connect_args=dict(check_same_thread=False),
)
|
|
|
|
# expire_on_commit=False keeps loaded attribute values after commit, so
# fixtures can hand back ORM objects that stay readable once committed.
_TestSessionLocal = async_sessionmaker(bind=_TEST_ENGINE, expire_on_commit=False)
|
|
|
|
|
|
# Enable foreign key enforcement for SQLite (off by default).
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_conn, _connection_record):  # noqa: ANN001
    """Turn on ``PRAGMA foreign_keys`` for each new raw DBAPI connection.

    Registered on the engine's ``connect`` event. The cursor is closed even if
    executing the PRAGMA raises, so a failure does not leak the cursor.
    """
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
    finally:
        cursor.close()
|
|
|
|
|
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture(autouse=True)
async def _create_tables():
    """Build the schema and seed per-tier users around every test.

    Runs automatically for each test. Before the test it creates every table
    in ``Base.metadata`` and inserts, for each entry of ``TEST_USER_IDS``, a
    ``User`` plus an active ``Subscription``. After the test it drops all
    tables so the next test starts from an empty database.
    """
    async with _TEST_ENGINE.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # One User + Subscription per tier so FK constraints and auth work.
    async with _TestSessionLocal() as session:
        for tier_name, user_id in TEST_USER_IDS.items():
            session.add_all(
                [
                    User(
                        id=user_id,
                        email=f"{tier_name}@test.com",
                        password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
                        tier=tier_name,
                    ),
                    Subscription(
                        id=str(uuid.uuid4()),
                        user_id=user_id,
                        tier=tier_name,
                        stripe_subscription_id=f"sub_test_{tier_name}",
                        status="active",
                    ),
                ]
            )
        await session.commit()

    yield

    async with _TEST_ENGINE.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Open an ``AsyncSession`` on the test engine for one test, closing it after."""
    session = _TestSessionLocal()
    async with session:
        yield session
|
|
|
|
|
|
@pytest.fixture
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]:
    """FastAPI test client with ``get_session`` overridden to use the test DB.

    The override is installed on the global ``app`` and is always removed on
    teardown. This includes the case where entering or exiting the
    ``TestClient`` context raises, so the override cannot leak into later tests.
    """

    async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
        yield db_session

    app.dependency_overrides[get_session] = _override_get_session
    try:
        with TestClient(app) as c:
            yield c
    finally:
        app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
# ── Seed data helpers ────────────────────────────────────────────────
|
|
|
|
def _approved_plugin(
    *,
    plugin_id: str,
    name: str,
    description: str,
    version: str,
    author_name: str,
    category: str,
    price_cents: int,
    permissions: list[str],
) -> Plugin:
    """Build an approved ``Plugin`` template with zero installs and rating.

    ``permissions`` is stored JSON-encoded, and the package key follows
    ``plugins/<id>/<version>/package.zip``.
    """
    return Plugin(
        id=plugin_id,
        name=name,
        description=description,
        version=version,
        author_name=author_name,
        category=category,
        price_cents=price_cents,
        permissions=json.dumps(permissions),
        status="approved",
        s3_package_key=f"plugins/{plugin_id}/{version}/package.zip",
        install_count=0,
        avg_rating=0.0,
    )


# Templates copied by the ``seed_plugins`` fixture.
_SEED_PLUGINS = [
    _approved_plugin(
        plugin_id="plugin-github-sync",
        name="GitHub Sync",
        description="Sync tasks with GitHub Issues and pull requests.",
        version="1.0.0",
        author_name="Adiuva",
        category="productivity",
        price_cents=0,
        permissions=["read:tasks", "write:tasks"],
    ),
    _approved_plugin(
        plugin_id="plugin-slack-notify",
        name="Slack Notifier",
        description="Post task and checkpoint updates to Slack channels.",
        version="1.2.0",
        author_name="Adiuva",
        category="communication",
        price_cents=499,
        permissions=["read:tasks", "read:checkpoints"],
    ),
    _approved_plugin(
        plugin_id="plugin-time-tracker",
        name="Time Tracker",
        description="Track time spent on tasks with automatic reporting.",
        version="0.9.1",
        author_name="Third Party",
        category="productivity",
        price_cents=999,
        permissions=["read:tasks", "write:tasks"],
    ),
]
|
|
|
|
|
|
@pytest_asyncio.fixture
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
    """Insert the 3 default approved plugins and return them.

    Fresh ``Plugin`` instances are built from the column values of each
    ``_SEED_PLUGINS`` template (the templates are only read). They are
    committed through ``db_session`` and returned in template order.
    """
    copied_fields = (
        "id",
        "name",
        "description",
        "version",
        "author_name",
        "category",
        "price_cents",
        "permissions",
        "status",
        "s3_package_key",
        "install_count",
        "avg_rating",
    )
    plugins = [
        Plugin(**{field: getattr(template, field) for field in copied_fields})
        for template in _SEED_PLUGINS
    ]
    db_session.add_all(plugins)
    await db_session.commit()
    return plugins
|
|
|
|
|
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
|
|
|
|
|
def make_jwt(
    tier: str = "power",
    user_id: str | None = None,
    email: str | None = None,
    *,
    expires_in: int = 3600,
) -> str:
    """Create a signed test JWT.

    Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can
    find the corresponding ``Subscription`` row in the test database.

    Args:
        tier: Value of the ``tier`` claim. It also selects the default user ID
            and e-mail.
        user_id: Explicit ``sub`` claim. Defaults to ``TEST_USER_IDS[tier]``,
            or a random UUID when *tier* is not in that mapping (such a user
            has no seeded row).
        email: Explicit ``email`` claim. Defaults to ``"<tier>@test.com"``.
        expires_in: Seconds from now until ``exp``. Pass a negative value to
            mint an already-expired token.

    Returns:
        The encoded token, signed with ``settings.JWT_SECRET`` using
        ``settings.JWT_ALGORITHM``.
    """
    # Only generate a random UUID when the tier has no fixed test user.
    uid = user_id or TEST_USER_IDS.get(tier) or str(uuid.uuid4())
    now = int(time.time())
    payload = {
        "sub": uid,
        "email": email or f"{tier}@test.com",
        "tier": tier,
        "exp": now + expires_in,
        "iat": now,
    }
    return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
|
|
|
|
|
def auth_header(
    tier: str = "power",
    user_id: str | None = None,
    email: str | None = None,
) -> dict[str, str]:
    """Return an Authorization header dict for the given tier.

    The value is ``"Bearer <token>"``, where the token comes from ``make_jwt``
    called with the same ``tier``, ``user_id`` and ``email`` arguments.
    """
    return {"Authorization": f"Bearer {make_jwt(tier, user_id, email)}"}
|
|
|
|
|
|
# ── S3 mock fixture ──────────────────────────────────────────────────
|
|
|
|
# Bucket name and region used by the ``s3_bucket`` fixture below.
S3_TEST_BUCKET, S3_TEST_REGION = "test-bucket", "us-east-1"
|
|
|
|
|
|
@pytest.fixture
def s3_bucket():
    """Create a mocked S3 bucket via moto and patch BlobStore settings.

    Fake AWS credentials and the test region are placed in ``os.environ``
    only for the lifetime of the fixture; ``patch.dict`` restores the previous
    environment afterwards. They are set before ``mock_aws`` starts.
    ``app.storage.blob_store.settings`` is replaced by a mock reporting the
    test bucket, region and credentials.

    Yields:
        The name of the created bucket (``S3_TEST_BUCKET``).
    """
    fake_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_DEFAULT_REGION": S3_TEST_REGION,
    }
    with patch.dict(os.environ, fake_env), mock_aws():
        # Named ``s3`` (not ``client``) to avoid shadowing the ``client`` fixture.
        s3 = boto3.client("s3", region_name=S3_TEST_REGION)
        s3.create_bucket(Bucket=S3_TEST_BUCKET)
        with patch("app.storage.blob_store.settings") as mock_settings:
            mock_settings.S3_BUCKET = S3_TEST_BUCKET
            mock_settings.S3_REGION = S3_TEST_REGION
            mock_settings.AWS_ACCESS_KEY_ID = "testing"
            mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
            yield S3_TEST_BUCKET
|