227 lines
7.2 KiB
Python
227 lines
7.2 KiB
Python
import time
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import ASGITransport, AsyncClient
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.main import app
|
|
from app.models import Base, WaitlistEntry
|
|
from app.db import get_db
|
|
from app.rate_limit import _hits_store
|
|
from app.token import generate_token, verify_token
|
|
|
|
# Tests run against SQLite through aiosqlite, so no Postgres server is needed.
# The URL points at a file relative to the current working directory.
TEST_DB_URL: str = "sqlite+aiosqlite:///./test_waitlist.db"
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_session():
    """Yield an ``AsyncSession`` bound to a freshly created test schema.

    Every table on ``Base.metadata`` is dropped and recreated before the test.
    The database named by ``TEST_DB_URL`` is a file on disk, so rows left
    behind by an earlier run that died before its teardown would otherwise
    leak into this test (``create_all`` does not remove existing rows).

    After the test the schema is dropped and the engine disposed. Both steps
    run even if setup or closing the session raises.
    """
    engine = create_async_engine(TEST_DB_URL)
    try:
        async with engine.begin() as conn:
            # Start from a clean slate: a previous run may not have reached
            # its teardown.
            await conn.run_sync(Base.metadata.drop_all)
            await conn.run_sync(Base.metadata.create_all)

        # expire_on_commit=False keeps attribute values of committed objects
        # loaded, so tests can keep reading them after commit().
        session_factory = async_sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        async with session_factory() as session:
            yield session
    finally:
        try:
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
        finally:
            # Always release pooled connections, even if drop_all failed.
            await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client(db_session):
    """Yield an ``AsyncClient`` that sends requests to ``app`` via ``ASGITransport``.

    While the fixture is active:

    * ``get_db`` is overridden, so the app and the test share ``db_session``.
    * ``_hits_store`` is emptied, so rate-limit counts start at zero.

    Afterwards the override is removed and the store is cleared again. This
    happens even if shutting down the client raises.
    """

    async def _override_db():
        # Hand the app the test's own session, so rows seeded by the test are
        # visible to request handlers and vice versa.
        yield db_session

    app.dependency_overrides[get_db] = _override_db

    # Reset rate limiter state so requests from earlier tests don't count here.
    _hits_store.clear()

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac
    finally:
        # Remove only the override installed above and leave any others intact.
        app.dependency_overrides.pop(get_db, None)
        # Don't leak this test's recorded hits into whatever runs next.
        _hits_store.clear()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_join_waitlist_success(client):
    """A valid email is accepted and the reply points the user to their inbox."""
    response = await client.post("/api/v1/waitlist", json={"email": "user@example.com"})
    assert response.status_code == 200

    body = response.json()
    assert body["ok"] is True
    assert "inbox" in body["message"].lower()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_duplicate_email_is_idempotent(client):
    """Signing up twice with the same address succeeds both times."""
    same_address = {"email": "dup@example.com"}

    first = await client.post("/api/v1/waitlist", json=same_address)
    second = await client.post("/api/v1/waitlist", json=same_address)

    assert (first.status_code, second.status_code) == (200, 200)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_email_rejected(client):
    """A malformed address fails request validation with HTTP 422."""
    malformed = {"email": "not-an-email"}
    response = await client.post("/api/v1/waitlist", json=malformed)
    assert response.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_honeypot_silently_succeeds(client):
    """A filled-in honeypot ``website`` field is rejected with HTTP 422.

    NOTE(review): the test name is misleading. The request does *not*
    silently succeed; as asserted below, it fails validation. The comment
    below attributes this to ``max_length=0`` on the schema's ``website``
    field. The name is kept so existing test IDs and ``-k`` selections still
    match.
    """
    resp = await client.post(
        "/api/v1/waitlist",
        json={"email": "bot@spam.com", "website": "http://spam.site"},
    )
    # Honeypot field filled → validation error (max_length=0).
    # The email itself is valid, so the 422 is expected to come from `website`.
    assert resp.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_missing_email_rejected(client):
    """A body without the required ``email`` key is rejected with HTTP 422."""
    empty_body = {}
    response = await client.post("/api/v1/waitlist", json=empty_body)
    assert response.status_code == 422
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_health_endpoint(client):
    """``/health`` answers 200 with ``status == "ok"`` in its JSON body."""
    response = await client.get("/health")
    assert response.status_code == 200

    reported_status = response.json()["status"]
    assert reported_status == "ok"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rate_limit(client):
    """Submit more than the per-minute limit and expect 429.

    The first ``limit`` requests must all be accepted. Only the next one is
    rejected with 429 and a ``Retry-After`` header.

    Checking the accepted requests too ensures the 429 comes from exceeding
    the limit. A limiter that rejects too early (or rejects everything) would
    otherwise pass this test.
    """
    # Per the original test comment the limit is 5; keep in sync with app.rate_limit.
    limit = 5

    for i in range(limit):
        resp = await client.post(
            "/api/v1/waitlist",
            json={"email": f"rate{i}@example.com"},
        )
        assert resp.status_code == 200, f"request {i + 1} of {limit} was rejected"

    # One request past the limit must be throttled.
    resp = await client.post(
        "/api/v1/waitlist",
        json={"email": f"rate{limit}@example.com"},
    )
    assert resp.status_code == 429
    assert "Retry-After" in resp.headers
|
|
|
|
|
|
# ── Confirmation token tests ─────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_roundtrip():
    """``verify_token`` recovers exactly the email that ``generate_token`` signed."""
    address = "token@example.com"
    assert verify_token(generate_token(address)) == address
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_expired():
    """A token minted 49 hours in the past no longer verifies (returns None)."""
    email = "expired@example.com"
    forty_nine_hours = 49 * 3600
    issued_at = time.time() - forty_nine_hours

    # Swap out app.token's `time` module only while the token is minted, so
    # the token carries the past timestamp.
    with patch("app.token.time") as fake_clock:
        fake_clock.time.return_value = issued_at
        token = generate_token(email)

    # Verification runs on the real clock. The original comment says tokens
    # expire after 48h, so this one must be rejected.
    assert verify_token(token) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_tampered():
    """A tampered token should return None.

    The altered character is the second-to-last non-padding one, not the last.
    This assumes the token ends in base64 data, which is typical for signed
    tokens; ``app.token`` is not visible here, so confirm.

    In base64, the final data character can carry pad bits that lenient
    decoders ignore (RFC 4648 §3.5). Swapping, say, "B" for "A" there can
    decode to identical bytes, so the token still verifies and the test fails
    intermittently. Every earlier character carries six meaningful bits, so
    changing it always changes the decoded data.
    """
    token = generate_token("legit@example.com")

    # Index of the second-to-last character, ignoring any trailing '=' padding.
    idx = len(token.rstrip("=")) - 2
    replacement = "B" if token[idx] == "A" else "A"
    tampered = f"{token[:idx]}{replacement}{token[idx + 1:]}"

    assert tampered != token
    assert verify_token(tampered) is None
|
|
|
|
|
|
# ── Confirm endpoint tests ───────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_confirm_valid_token(client, db_session):
    """GET /confirm with a valid token marks the stored entry as confirmed."""
    address = "confirm@example.com"

    # Seed an entry that has not been confirmed yet.
    seeded = WaitlistEntry(email=address, source="website")
    db_session.add(seeded)
    await db_session.commit()

    link_token = generate_token(address)
    response = await client.get(f"/api/v1/waitlist/confirm?token={link_token}")
    assert response.status_code == 200
    page = response.text.lower()
    assert "confirmed" in page or "verified" in page

    # The persisted row itself must now be flagged as confirmed.
    query = select(WaitlistEntry).where(WaitlistEntry.email == address)
    stored = (await db_session.execute(query)).scalar_one()
    assert stored.confirmed is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_confirm_invalid_token(client):
    """GET /confirm with a garbage token is refused with HTTP 400."""
    response = await client.get("/api/v1/waitlist/confirm?token=garbage")
    assert response.status_code == 400

    page = response.text.lower()
    assert "invalid" in page or "expired" in page
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_confirm_idempotent(client, db_session):
    """Confirming an address that is already confirmed still answers 200."""
    already_confirmed = WaitlistEntry(
        email="idem@example.com", source="website", confirmed=True
    )
    db_session.add(already_confirmed)
    await db_session.commit()

    link_token = generate_token("idem@example.com")
    response = await client.get(f"/api/v1/waitlist/confirm?token={link_token}")
    assert response.status_code == 200
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_confirm_unknown_email(client):
    """A correctly signed token for an address never added to the list gets 400."""
    link_token = generate_token("unknown@example.com")
    response = await client.get(f"/api/v1/waitlist/confirm?token={link_token}")
    assert response.status_code == 400
|
|
|
|
|
|
# ── Brevo integration tests (mocked) ────────────────────────────────
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_signup_triggers_confirmation_email(client, db_session):
    """When Brevo is configured, signup sends a confirmation email.

    Per the original comment, the email is sent from a fire-and-forget task.
    Sleeping a fixed 100 ms is either wasted time or too short on a slow CI
    runner. Instead, the test yields to the event loop until the mock has
    been called or a 2-second deadline passes.

    Waiting and asserting both happen inside the ``with`` block, so the
    background work still sees the patched ``send_confirmation_email``.
    """
    import asyncio

    with patch("app.routes.settings") as mock_settings, \
            patch("app.routes.send_confirmation_email", new_callable=AsyncMock) as mock_send:
        mock_settings.brevo_configured = True
        mock_settings.CONFIRM_BASE_URL = "http://test"

        resp = await client.post(
            "/api/v1/waitlist",
            json={"email": "brevo@example.com"},
        )
        assert resp.status_code == 200

        # Wait for fire-and-forget task: poll instead of a fixed sleep.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 2.0
        while not mock_send.called and loop.time() < deadline:
            await asyncio.sleep(0.01)

        mock_send.assert_called_once()
        positional = mock_send.call_args.args
        assert positional[0] == "brevo@example.com"
        # Second positional argument is expected to be the confirmation link.
        assert "confirm" in positional[1]
|