Initial commit: waitlist microservice
This commit is contained in:
11
.env.example
Normal file
11
.env.example
Normal file
@@ -0,0 +1,11 @@
|
||||
# Database
|
||||
DATABASE_URL=postgresql+asyncpg://waitlist:changeme@localhost:5432/waitlist_db
|
||||
|
||||
# CORS — comma-separated allowed origins
|
||||
ALLOWED_ORIGINS=https://adiuvai.com,https://www.adiuvai.com
|
||||
|
||||
# Rate limiting
|
||||
RATE_LIMIT_PER_MINUTE=5
|
||||
|
||||
# Set to "production" in prod to enforce strict origin checks
|
||||
ENVIRONMENT=development
|
||||
97
.gitea/workflows/deploy.yaml
Normal file
97
.gitea/workflows/deploy.yaml
Normal file
@@ -0,0 +1,97 @@
|
||||
name: Test & Deploy Waitlist
|
||||
run-name: ${{ gitea.ref_name }} → Docker LXC
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
# ── 1. Run tests in an isolated Python container ──────────────────
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: python:3.12-slim
|
||||
|
||||
steps:
|
||||
- name: Install git
|
||||
run: apt-get update && apt-get install -y --no-install-recommends git
|
||||
|
||||
- name: Checkout Code
|
||||
run: |
|
||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||
"http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . || \
|
||||
git clone --depth 1 "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . && \
|
||||
git checkout "${GITHUB_SHA}"
|
||||
|
||||
- name: Install Dependencies
|
||||
run: pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
- name: Install Test Dependencies
|
||||
run: pip install --no-cache-dir pytest pytest-asyncio httpx aiosqlite
|
||||
|
||||
- name: Run Linter
|
||||
run: ruff check app/ tests/
|
||||
|
||||
- name: Run Tests
|
||||
run: pytest tests/ -v --tb=short
|
||||
|
||||
# ── 2. Deploy to Docker LXC via SSH ─────────────────────────────────
|
||||
deploy:
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
if: gitea.event_name == 'push'
|
||||
|
||||
steps:
|
||||
- name: Deploy via SSH
|
||||
uses: appleboy/ssh-action@v1.0.0
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_KEY }}
|
||||
script: |
|
||||
set -e
|
||||
DEPLOY_DIR="/opt/adiuvai-waitlist"
|
||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||
TAG="${{ gitea.ref_name }}"
|
||||
|
||||
# ── Pull latest code ──
|
||||
cd /tmp && rm -rf adiuvai-waitlist-deploy
|
||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-waitlist-deploy
|
||||
|
||||
# ── Sync source (preserve .env) ──
|
||||
cp -rf /tmp/adiuvai-waitlist-deploy/app/ \
|
||||
/tmp/adiuvai-waitlist-deploy/alembic/ \
|
||||
/tmp/adiuvai-waitlist-deploy/alembic.ini \
|
||||
/tmp/adiuvai-waitlist-deploy/Dockerfile \
|
||||
/tmp/adiuvai-waitlist-deploy/docker-compose.yml \
|
||||
/tmp/adiuvai-waitlist-deploy/requirements.txt \
|
||||
"$DEPLOY_DIR/"
|
||||
rm -rf /tmp/adiuvai-waitlist-deploy
|
||||
|
||||
# ── Verify .env ──
|
||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||
echo "❌ $DEPLOY_DIR/.env not found. Create it before deploying."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ── Build & restart (app only — keep DB running) ──
|
||||
cd "$DEPLOY_DIR"
|
||||
docker compose up -d --build --no-deps db # ensure DB is running
|
||||
docker compose build app # rebuild app image
|
||||
docker compose up -d --no-deps app # restart only app container
|
||||
|
||||
# ── Migrations ──
|
||||
docker compose exec -T app alembic upgrade head
|
||||
|
||||
# ── Health check ──
|
||||
echo "Waiting for app..."
|
||||
sleep 5
|
||||
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8001/health)
|
||||
if [ "$HTTP_CODE" -eq 200 ]; then
|
||||
echo "✅ Waitlist service is healthy (HTTP ${HTTP_CODE})"
|
||||
else
|
||||
echo "❌ Health check failed (HTTP ${HTTP_CODE})"
|
||||
docker compose logs app --tail=50
|
||||
exit 1
|
||||
fi
|
||||
31
.gitignore
vendored
Normal file
31
.gitignore
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Virtual environment
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Testing / coverage
|
||||
.pytest_cache/
|
||||
htmlcov/
|
||||
.coverage
|
||||
test_waitlist.db
|
||||
|
||||
# Docker
|
||||
*.log
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
19
Dockerfile
Normal file
19
Dockerfile
Normal file
@@ -0,0 +1,19 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install only psycopg2 build deps (needed for alembic sync driver)
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends libpq-dev gcc && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt psycopg2-binary
|
||||
|
||||
COPY alembic.ini .
|
||||
COPY alembic/ alembic/
|
||||
COPY app/ app/
|
||||
|
||||
EXPOSE 8001
|
||||
|
||||
CMD ["sh", "-c", "alembic upgrade head && gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 2 -b 0.0.0.0:8001 --timeout 30"]
|
||||
69
README.md
Normal file
69
README.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# adiuvAI Waitlist Service
|
||||
|
||||
Minimal FastAPI microservice that stores waitlist email signups in PostgreSQL.
|
||||
|
||||
## Security
|
||||
|
||||
Designed to sit behind **Cloudflare** (WAF + DDoS protection). Additional hardening:
|
||||
|
||||
| Layer | What |
|
||||
|-------|------|
|
||||
| **Cloudflare** | WAF, bot management, DDoS mitigation (external) |
|
||||
| **Rate limiter** | 5 req/min per IP, Cloudflare-aware (`CF-Connecting-IP`) |
|
||||
| **Origin validation** | Rejects POST without valid `Origin`/`Referer` in production |
|
||||
| **CORS** | Locked to `adiuvai.com` origins only |
|
||||
| **Honeypot field** | Hidden `website` field — bots that fill it get a silent 200 |
|
||||
| **Request size limit** | 4 KB max body (email payload is ~100 bytes) |
|
||||
| **Input validation** | Pydantic `EmailStr` with normalization |
|
||||
| **SQL injection** | SQLAlchemy parameterized queries (no raw SQL) |
|
||||
| **No PII leakage** | Errors return generic messages, no email reflection |
|
||||
| **Docs disabled in prod** | `/docs` and `/openapi.json` only in development |
|
||||
| **Idempotent** | Duplicate emails return success (no enumeration) |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Start Postgres + app
|
||||
docker compose up --build
|
||||
|
||||
# 2. Test
|
||||
curl -X POST https://waitlist.adiuvai.com/api/v1/waitlist \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"email": "user@example.com"}'
|
||||
```
|
||||
|
||||
## Local Development
|
||||
|
||||
```bash
|
||||
cd waitlist
|
||||
python -m venv .venv
|
||||
.venv\Scripts\Activate.ps1 # Windows
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Copy and edit .env
|
||||
cp .env.example .env
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Start dev server
|
||||
uvicorn app.main:app --reload --port 8001
|
||||
|
||||
# Run tests
|
||||
pip install pytest pytest-asyncio httpx aiosqlite
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
## Deployment (Cloudflare)
|
||||
|
||||
1. Point `waitlist.adiuvai.com` to your server via Cloudflare DNS (orange cloud ON)
|
||||
2. Set environment variables (see `.env.example`)
|
||||
3. `docker compose up -d`
|
||||
4. Cloudflare handles TLS termination, bot filtering, and rate limiting at the edge
|
||||
|
||||
### Recommended Cloudflare Settings
|
||||
|
||||
- **WAF**: Enable managed rulesets (OWASP Core)
|
||||
- **Bot Fight Mode**: ON
|
||||
- **Rate Limiting Rule**: 10 req/10s to `/api/v1/waitlist` (defense in depth)
|
||||
- **SSL mode**: Full (Strict)
|
||||
36
alembic.ini
Normal file
36
alembic.ini
Normal file
@@ -0,0 +1,36 @@
|
||||
[alembic]
|
||||
script_location = alembic
|
||||
sqlalchemy.url = postgresql+psycopg2://waitlist:changeme@localhost:5432/waitlist_db
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
46
alembic/env.py
Normal file
46
alembic/env.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
from app.models import Base
|
||||
|
||||
config = context.config
|
||||
|
||||
# Override URL from env if available (so .env takes precedence over alembic.ini)
|
||||
db_url = os.environ.get("DATABASE_URL", "")
|
||||
if db_url:
|
||||
# Alembic needs the sync driver
|
||||
sync_url = db_url.replace("+asyncpg", "+psycopg2")
|
||||
config.set_main_option("sqlalchemy.url", sync_url)
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(url=url, target_metadata=target_metadata, literal_binds=True)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
25
alembic/script.py.mako
Normal file
25
alembic/script.py.mako
Normal file
@@ -0,0 +1,25 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
36
alembic/versions/001_create_waitlist_entries.py
Normal file
36
alembic/versions/001_create_waitlist_entries.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""create waitlist_entries table
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2026-04-11
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "001"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"waitlist_entries",
|
||||
sa.Column("id", sa.BigInteger(), autoincrement=True, primary_key=True),
|
||||
sa.Column("email", sa.String(320), nullable=False, unique=True, index=True),
|
||||
sa.Column("ip_address", sa.String(45), nullable=True),
|
||||
sa.Column("source", sa.String(64), nullable=True),
|
||||
sa.Column("confirmed", sa.Boolean(), server_default="false", nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("waitlist_entries")
|
||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
21
app/config.py
Normal file
21
app/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://waitlist:changeme@localhost:5432/waitlist_db"
|
||||
ALLOWED_ORIGINS: str = "https://adiuvai.com,https://www.adiuvai.com"
|
||||
RATE_LIMIT_PER_MINUTE: int = 5
|
||||
ENVIRONMENT: str = "development"
|
||||
|
||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
@property
|
||||
def origins_list(self) -> list[str]:
|
||||
return [o.strip() for o in self.ALLOWED_ORIGINS.split(",") if o.strip()]
|
||||
|
||||
@property
|
||||
def sync_database_url(self) -> str:
|
||||
return self.DATABASE_URL.replace("+asyncpg", "+psycopg2")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
17
app/db.py
Normal file
17
app/db.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
59
app/main.py
Normal file
59
app/main.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import logging
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.rate_limit import RateLimiter
|
||||
from app.routes import router
|
||||
from app.security import OriginValidator, RequestSizeLimiter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Waitlist service starting (env=%s)", settings.ENVIRONMENT)
|
||||
yield
|
||||
logger.info("Waitlist service shutting down")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="adiuvAI Waitlist",
|
||||
version="1.0.0",
|
||||
docs_url="/docs" if settings.ENVIRONMENT != "production" else None,
|
||||
redoc_url=None,
|
||||
openapi_url="/openapi.json" if settings.ENVIRONMENT != "production" else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# ── Middleware stack (outermost runs first) ──────────────────────────
|
||||
|
||||
# 1. CORS — locked to allowed origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.origins_list,
|
||||
allow_methods=["POST", "OPTIONS"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=False,
|
||||
max_age=86400,
|
||||
)
|
||||
|
||||
# 2. Rate limiter (per-IP, Cloudflare-aware)
|
||||
app.add_middleware(RateLimiter)
|
||||
|
||||
# 3. Origin / Referer validation (production only)
|
||||
app.add_middleware(OriginValidator)
|
||||
|
||||
# 4. Request body size limit (4 KB)
|
||||
app.add_middleware(RequestSizeLimiter)
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
28
app/models.py
Normal file
28
app/models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, String
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class WaitlistEntry(Base):
|
||||
__tablename__ = "waitlist_entries"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger().with_variant(sa.Integer, "sqlite"),
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
source: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
confirmed: Mapped[bool] = mapped_column(Boolean, default=False, server_default=sa.text("0"))
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
nullable=False,
|
||||
)
|
||||
76
app/rate_limit.py
Normal file
76
app/rate_limit.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
IP-based sliding-window rate limiter.
|
||||
|
||||
Cloudflare-aware: uses CF-Connecting-IP → X-Forwarded-For → client.host
|
||||
to identify the real client IP.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract real client IP behind Cloudflare / reverse proxy."""
|
||||
# Cloudflare always sets this when proxying
|
||||
cf_ip = request.headers.get("cf-connecting-ip")
|
||||
if cf_ip:
|
||||
return cf_ip.strip()
|
||||
|
||||
# Fallback: first entry in X-Forwarded-For (set by most reverse proxies)
|
||||
xff = request.headers.get("x-forwarded-for")
|
||||
if xff:
|
||||
return xff.split(",")[0].strip()
|
||||
|
||||
# Last resort: direct connection IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
# Module-level hits store so tests can clear it
|
||||
_hits_store: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
class RateLimiter(BaseHTTPMiddleware):
|
||||
"""
|
||||
Sliding-window rate limiter keyed on client IP.
|
||||
|
||||
Only applies to POST /api/v1/waitlist.
|
||||
Returns 429 with Retry-After header when exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, app, per_minute: int = settings.RATE_LIMIT_PER_MINUTE):
|
||||
super().__init__(app)
|
||||
self.per_minute = per_minute
|
||||
self.window = 60 # seconds
|
||||
self._hits = _hits_store
|
||||
self._lock = Lock()
|
||||
|
||||
def _prune(self, ip: str, now: float) -> None:
|
||||
cutoff = now - self.window
|
||||
self._hits[ip] = [t for t in self._hits[ip] if t > cutoff]
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Only rate-limit the waitlist POST endpoint
|
||||
if request.method != "POST" or request.url.path != "/api/v1/waitlist":
|
||||
return await call_next(request)
|
||||
|
||||
ip = _get_client_ip(request)
|
||||
now = time.monotonic()
|
||||
|
||||
with self._lock:
|
||||
self._prune(ip, now)
|
||||
if len(self._hits[ip]) >= self.per_minute:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Too many requests. Please try again later."},
|
||||
headers={"Retry-After": str(self.window)},
|
||||
)
|
||||
self._hits[ip].append(now)
|
||||
|
||||
return await call_next(request)
|
||||
48
app/routes.py
Normal file
48
app/routes.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.rate_limit import _get_client_ip
|
||||
from app.schemas import WaitlistRequest, WaitlistResponse
|
||||
from app.models import WaitlistEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/waitlist", response_model=WaitlistResponse)
|
||||
async def join_waitlist(
|
||||
body: WaitlistRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> WaitlistResponse:
|
||||
"""
|
||||
Add an email to the waitlist.
|
||||
|
||||
- Honeypot: if `website` field is non-empty, silently succeed (bot trap).
|
||||
- Duplicate emails: idempotent — returns success without error.
|
||||
- Stores the Cloudflare-resolved client IP for analytics (not exposed).
|
||||
"""
|
||||
# Honeypot — bots fill hidden fields; silently "succeed"
|
||||
if body.website:
|
||||
return WaitlistResponse()
|
||||
|
||||
email = body.email.lower().strip()
|
||||
ip = _get_client_ip(request)
|
||||
|
||||
# Check for existing entry — idempotent
|
||||
existing = await db.execute(
|
||||
select(WaitlistEntry.id).where(WaitlistEntry.email == email)
|
||||
)
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
return WaitlistResponse()
|
||||
|
||||
entry = WaitlistEntry(email=email, ip_address=ip, source="website")
|
||||
db.add(entry)
|
||||
await db.commit()
|
||||
|
||||
logger.info("New waitlist signup: %s", email[:3] + "***")
|
||||
return WaitlistResponse()
|
||||
12
app/schemas.py
Normal file
12
app/schemas.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class WaitlistRequest(BaseModel):
|
||||
email: EmailStr
|
||||
# Honeypot field — must be empty. Bots tend to fill hidden fields.
|
||||
website: str = Field(default="", max_length=0)
|
||||
|
||||
|
||||
class WaitlistResponse(BaseModel):
|
||||
ok: bool = True
|
||||
message: str = "You're on the list!"
|
||||
59
app/security.py
Normal file
59
app/security.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Security middleware stack.
|
||||
|
||||
1. RequestSizeLimiter — reject bodies > 4 KB (waitlist only needs ~100 bytes)
|
||||
2. OriginValidator — in production, reject requests without a valid Origin/Referer
|
||||
"""
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class RequestSizeLimiter(BaseHTTPMiddleware):
|
||||
"""Reject request bodies larger than max_bytes."""
|
||||
|
||||
MAX_BYTES = 4_096 # 4 KB — more than enough for a JSON email payload
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length and int(content_length) > self.MAX_BYTES:
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={"detail": "Request body too large."},
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class OriginValidator(BaseHTTPMiddleware):
|
||||
"""
|
||||
In production, only allow requests whose Origin or Referer
|
||||
matches the allowed origins list. This mitigates CSRF/cross-origin abuse.
|
||||
|
||||
Skipped in development so local testing works without custom headers.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if settings.ENVIRONMENT != "production":
|
||||
return await call_next(request)
|
||||
|
||||
# Only check mutating methods
|
||||
if request.method not in ("POST", "PUT", "PATCH", "DELETE"):
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("origin") or ""
|
||||
referer = request.headers.get("referer") or ""
|
||||
|
||||
allowed = settings.origins_list
|
||||
origin_ok = any(origin.startswith(o) for o in allowed) if origin else False
|
||||
referer_ok = any(referer.startswith(o) for o in allowed) if referer else False
|
||||
|
||||
if not origin_ok and not referer_ok:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Forbidden."},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
32
docker-compose.yml
Normal file
32
docker-compose.yml
Normal file
@@ -0,0 +1,32 @@
|
||||
services:
|
||||
db:
|
||||
image: postgres:16-alpine
|
||||
environment:
|
||||
POSTGRES_USER: waitlist
|
||||
POSTGRES_PASSWORD: changeme
|
||||
POSTGRES_DB: waitlist_db
|
||||
ports:
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- pgdata:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U waitlist -d waitlist_db"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
|
||||
app:
|
||||
build: .
|
||||
ports:
|
||||
- "8001:8001"
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://waitlist:changeme@db:5432/waitlist_db
|
||||
ALLOWED_ORIGINS: https://adiuvai.com,https://www.adiuvai.com
|
||||
RATE_LIMIT_PER_MINUTE: 5
|
||||
ENVIRONMENT: production
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
|
||||
volumes:
|
||||
pgdata:
|
||||
10
requirements.txt
Normal file
10
requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
fastapi[standard]>=0.115,<1.0
|
||||
uvicorn[standard]>=0.34,<1.0
|
||||
gunicorn>=23,<24
|
||||
sqlalchemy[asyncio]>=2.0,<3.0
|
||||
asyncpg>=0.30,<1.0
|
||||
alembic>=1.14,<2.0
|
||||
pydantic>=2.0,<3.0
|
||||
pydantic-settings>=2.0,<3.0
|
||||
email-validator>=2.0,<3.0
|
||||
python-dotenv>=1.0,<2.0
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
109
tests/test_waitlist.py
Normal file
109
tests/test_waitlist.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.main import app
|
||||
from app.models import Base
|
||||
from app.db import get_db
|
||||
from app.rate_limit import _hits_store
|
||||
|
||||
# Use SQLite for tests (no Postgres dependency)
|
||||
TEST_DB_URL = "sqlite+aiosqlite:///./test_waitlist.db"
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_session():
|
||||
engine = create_async_engine(TEST_DB_URL)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
session_factory = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(db_session):
|
||||
async def _override_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_db
|
||||
|
||||
# Reset rate limiter state between tests
|
||||
_hits_store.clear()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_waitlist_success(client):
|
||||
resp = await client.post(
|
||||
"/api/v1/waitlist",
|
||||
json={"email": "user@example.com"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert "list" in data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_duplicate_email_is_idempotent(client):
|
||||
payload = {"email": "dup@example.com"}
|
||||
r1 = await client.post("/api/v1/waitlist", json=payload)
|
||||
r2 = await client.post("/api/v1/waitlist", json=payload)
|
||||
assert r1.status_code == 200
|
||||
assert r2.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_email_rejected(client):
|
||||
resp = await client.post(
|
||||
"/api/v1/waitlist",
|
||||
json={"email": "not-an-email"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_honeypot_silently_succeeds(client):
|
||||
resp = await client.post(
|
||||
"/api/v1/waitlist",
|
||||
json={"email": "bot@spam.com", "website": "http://spam.site"},
|
||||
)
|
||||
# Honeypot field filled → validation error (max_length=0)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_email_rejected(client):
|
||||
resp = await client.post("/api/v1/waitlist", json={})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit(client):
|
||||
"""Submit more than the per-minute limit and expect 429."""
|
||||
for i in range(6):
|
||||
resp = await client.post(
|
||||
"/api/v1/waitlist",
|
||||
json={"email": f"rate{i}@example.com"},
|
||||
)
|
||||
# The 6th request should be rate-limited (limit is 5)
|
||||
assert resp.status_code == 429
|
||||
assert "Retry-After" in resp.headers
|
||||
Reference in New Issue
Block a user