110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.main import app
|
|
from app.models import Base
|
|
from app.db import get_db
|
|
from app.rate_limit import _hits_store
|
|
|
|
# The suite runs against SQLite (via the async aiosqlite driver) so it needs
# no Postgres server. The database is a file relative to the current working
# directory.
TEST_DB_URL: str = "sqlite+aiosqlite:///./test_waitlist.db"
@pytest_asyncio.fixture
async def db_session():
    """Yield an ``AsyncSession`` bound to a freshly created test schema.

    Setup drops and recreates every table in ``Base.metadata``. The SQLite
    database at ``TEST_DB_URL`` is a file on disk, so a previous run that
    died before teardown could otherwise leave stale rows behind.

    Teardown drops all tables again and disposes of the engine. The nested
    ``try``/``finally`` guarantees that ``engine.dispose()`` runs even if
    schema setup, session close, or ``drop_all`` raises, so pooled
    connections are never leaked.
    """
    engine = create_async_engine(TEST_DB_URL)
    try:
        async with engine.begin() as conn:
            # Start from a clean slate in case an earlier run was interrupted.
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        session_factory = async_sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        try:
            async with session_factory() as session:
                yield session
        finally:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
    finally:
        await engine.dispose()
@pytest_asyncio.fixture
async def client(db_session):
    """Yield an ``AsyncClient`` that talks to ``app`` in-process.

    ``get_db`` is overridden so that every request made during the test
    receives the same ``db_session``.

    The rate limiter's module-level ``_hits_store`` is cleared before and
    after the test, so hit counts never carry over between tests.

    Overrides are cleared in ``finally``. They live on the shared ``app``
    object, and they must not stay installed if opening or closing the
    client raises.
    """

    async def _override_db():
        # The db_session fixture owns the session's lifecycle; just hand it out.
        yield db_session

    app.dependency_overrides[get_db] = _override_db
    # Reset rate limiter state so earlier tests' hits don't count against this one.
    _hits_store.clear()
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
        # Leave no rate-limit state behind for whatever runs next.
        _hits_store.clear()
@pytest.mark.asyncio
async def test_join_waitlist_success(client):
    """A valid email is accepted: 200, ``ok`` is true, and the message mentions the list."""
    response = await client.post("/api/v1/waitlist", json={"email": "user@example.com"})
    assert response.status_code == 200

    body = response.json()
    assert body["ok"] is True
    assert "list" in body["message"].lower()
@pytest.mark.asyncio
async def test_duplicate_email_is_idempotent(client):
    """Signing up twice with the same email returns 200 both times."""
    payload = {"email": "dup@example.com"}
    first = await client.post("/api/v1/waitlist", json=payload)
    second = await client.post("/api/v1/waitlist", json=payload)

    # The repeat submission must not be treated as an error.
    assert (first.status_code, second.status_code) == (200, 200)
@pytest.mark.asyncio
async def test_invalid_email_rejected(client):
    """A malformed email address is rejected with HTTP 422."""
    bad_payload = {"email": "not-an-email"}
    response = await client.post("/api/v1/waitlist", json=bad_payload)
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_honeypot_silently_succeeds(client):
    """A submission with the ``website`` honeypot field filled in is rejected with HTTP 422.

    NOTE(review): despite its name, this test asserts rejection, not silent
    success. The name is kept so the test ID stays stable; consider renaming
    it (e.g. ``test_honeypot_rejected``).
    """
    bot_submission = {"email": "bot@spam.com", "website": "http://spam.site"}
    response = await client.post("/api/v1/waitlist", json=bot_submission)
    # A filled honeypot fails validation (reportedly a max_length=0 constraint
    # on the field; the schema is not visible here) instead of being ignored.
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_missing_email_rejected(client):
    """A request body without an ``email`` key is rejected with HTTP 422."""
    empty_payload = {}
    response = await client.post("/api/v1/waitlist", json=empty_payload)
    assert response.status_code == 422
@pytest.mark.asyncio
async def test_health_endpoint(client):
    """``GET /health`` answers 200 with a JSON body whose ``status`` is ``"ok"``."""
    health = await client.get("/health")
    assert health.status_code == 200

    payload = health.json()
    assert payload["status"] == "ok"
@pytest.mark.asyncio
async def test_rate_limit(client):
    """Submit more than the per-minute limit and expect 429.

    Every request within the limit must succeed. Only the first request past
    the limit is rejected, and that response must carry a ``Retry-After``
    header.

    Checking the in-limit responses too makes the test fail when the limit is
    lower than expected, not only when limiting is missing altogether.
    """
    limit = 5  # per-minute submissions the rate limiter allows

    for i in range(limit):
        resp = await client.post(
            "/api/v1/waitlist",
            json={"email": f"rate{i}@example.com"},
        )
        # Distinct emails, so each in-limit request is a fresh, valid signup.
        assert resp.status_code == 200, f"request {i + 1} of {limit} was rejected"

    # The request right after the limit is exhausted must be throttled.
    resp = await client.post(
        "/api/v1/waitlist",
        json={"email": f"rate{limit}@example.com"},
    )
    assert resp.status_code == 429
    assert "Retry-After" in resp.headers