Initial commit: waitlist microservice
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
21
app/config.py
Normal file
21
app/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://waitlist:changeme@localhost:5432/waitlist_db"
|
||||
ALLOWED_ORIGINS: str = "https://adiuvai.com,https://www.adiuvai.com"
|
||||
RATE_LIMIT_PER_MINUTE: int = 5
|
||||
ENVIRONMENT: str = "development"
|
||||
|
||||
model_config = {"env_file": ".env", "env_file_encoding": "utf-8"}
|
||||
|
||||
@property
|
||||
def origins_list(self) -> list[str]:
|
||||
return [o.strip() for o in self.ALLOWED_ORIGINS.split(",") if o.strip()]
|
||||
|
||||
@property
|
||||
def sync_database_url(self) -> str:
|
||||
return self.DATABASE_URL.replace("+asyncpg", "+psycopg2")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
17
app/db.py
Normal file
17
app/db.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from app.config import settings
|
||||
|
||||
engine = create_async_engine(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=5,
|
||||
max_overflow=10,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
async def get_db() -> AsyncSession:
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
59
app/main.py
Normal file
59
app/main.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import logging
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config import settings
|
||||
from app.rate_limit import RateLimiter
|
||||
from app.routes import router
|
||||
from app.security import OriginValidator, RequestSizeLimiter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
logger.info("Waitlist service starting (env=%s)", settings.ENVIRONMENT)
|
||||
yield
|
||||
logger.info("Waitlist service shutting down")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="adiuvAI Waitlist",
|
||||
version="1.0.0",
|
||||
docs_url="/docs" if settings.ENVIRONMENT != "production" else None,
|
||||
redoc_url=None,
|
||||
openapi_url="/openapi.json" if settings.ENVIRONMENT != "production" else None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# ── Middleware stack (outermost runs first) ──────────────────────────
|
||||
|
||||
# 1. CORS — locked to allowed origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.origins_list,
|
||||
allow_methods=["POST", "OPTIONS"],
|
||||
allow_headers=["Content-Type"],
|
||||
allow_credentials=False,
|
||||
max_age=86400,
|
||||
)
|
||||
|
||||
# 2. Rate limiter (per-IP, Cloudflare-aware)
|
||||
app.add_middleware(RateLimiter)
|
||||
|
||||
# 3. Origin / Referer validation (production only)
|
||||
app.add_middleware(OriginValidator)
|
||||
|
||||
# 4. Request body size limit (4 KB)
|
||||
app.add_middleware(RequestSizeLimiter)
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
28
app/models.py
Normal file
28
app/models.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, String
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class WaitlistEntry(Base):
|
||||
__tablename__ = "waitlist_entries"
|
||||
|
||||
id: Mapped[int] = mapped_column(
|
||||
BigInteger().with_variant(sa.Integer, "sqlite"),
|
||||
primary_key=True,
|
||||
autoincrement=True,
|
||||
)
|
||||
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
|
||||
ip_address: Mapped[str | None] = mapped_column(String(45), nullable=True)
|
||||
source: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
confirmed: Mapped[bool] = mapped_column(Boolean, default=False, server_default=sa.text("0"))
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=sa.text("CURRENT_TIMESTAMP"),
|
||||
nullable=False,
|
||||
)
|
||||
76
app/rate_limit.py
Normal file
76
app/rate_limit.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
IP-based sliding-window rate limiter.
|
||||
|
||||
Cloudflare-aware: uses CF-Connecting-IP → X-Forwarded-For → client.host
|
||||
to identify the real client IP.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
"""Extract real client IP behind Cloudflare / reverse proxy."""
|
||||
# Cloudflare always sets this when proxying
|
||||
cf_ip = request.headers.get("cf-connecting-ip")
|
||||
if cf_ip:
|
||||
return cf_ip.strip()
|
||||
|
||||
# Fallback: first entry in X-Forwarded-For (set by most reverse proxies)
|
||||
xff = request.headers.get("x-forwarded-for")
|
||||
if xff:
|
||||
return xff.split(",")[0].strip()
|
||||
|
||||
# Last resort: direct connection IP
|
||||
return request.client.host if request.client else "unknown"
|
||||
|
||||
|
||||
# Module-level hits store so tests can clear it
|
||||
_hits_store: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
|
||||
class RateLimiter(BaseHTTPMiddleware):
|
||||
"""
|
||||
Sliding-window rate limiter keyed on client IP.
|
||||
|
||||
Only applies to POST /api/v1/waitlist.
|
||||
Returns 429 with Retry-After header when exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, app, per_minute: int = settings.RATE_LIMIT_PER_MINUTE):
|
||||
super().__init__(app)
|
||||
self.per_minute = per_minute
|
||||
self.window = 60 # seconds
|
||||
self._hits = _hits_store
|
||||
self._lock = Lock()
|
||||
|
||||
def _prune(self, ip: str, now: float) -> None:
|
||||
cutoff = now - self.window
|
||||
self._hits[ip] = [t for t in self._hits[ip] if t > cutoff]
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
# Only rate-limit the waitlist POST endpoint
|
||||
if request.method != "POST" or request.url.path != "/api/v1/waitlist":
|
||||
return await call_next(request)
|
||||
|
||||
ip = _get_client_ip(request)
|
||||
now = time.monotonic()
|
||||
|
||||
with self._lock:
|
||||
self._prune(ip, now)
|
||||
if len(self._hits[ip]) >= self.per_minute:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Too many requests. Please try again later."},
|
||||
headers={"Retry-After": str(self.window)},
|
||||
)
|
||||
self._hits[ip].append(now)
|
||||
|
||||
return await call_next(request)
|
||||
48
app/routes.py
Normal file
48
app/routes.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import get_db
|
||||
from app.rate_limit import _get_client_ip
|
||||
from app.schemas import WaitlistRequest, WaitlistResponse
|
||||
from app.models import WaitlistEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/waitlist", response_model=WaitlistResponse)
|
||||
async def join_waitlist(
|
||||
body: WaitlistRequest,
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> WaitlistResponse:
|
||||
"""
|
||||
Add an email to the waitlist.
|
||||
|
||||
- Honeypot: if `website` field is non-empty, silently succeed (bot trap).
|
||||
- Duplicate emails: idempotent — returns success without error.
|
||||
- Stores the Cloudflare-resolved client IP for analytics (not exposed).
|
||||
"""
|
||||
# Honeypot — bots fill hidden fields; silently "succeed"
|
||||
if body.website:
|
||||
return WaitlistResponse()
|
||||
|
||||
email = body.email.lower().strip()
|
||||
ip = _get_client_ip(request)
|
||||
|
||||
# Check for existing entry — idempotent
|
||||
existing = await db.execute(
|
||||
select(WaitlistEntry.id).where(WaitlistEntry.email == email)
|
||||
)
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
return WaitlistResponse()
|
||||
|
||||
entry = WaitlistEntry(email=email, ip_address=ip, source="website")
|
||||
db.add(entry)
|
||||
await db.commit()
|
||||
|
||||
logger.info("New waitlist signup: %s", email[:3] + "***")
|
||||
return WaitlistResponse()
|
||||
12
app/schemas.py
Normal file
12
app/schemas.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class WaitlistRequest(BaseModel):
|
||||
email: EmailStr
|
||||
# Honeypot field — must be empty. Bots tend to fill hidden fields.
|
||||
website: str = Field(default="", max_length=0)
|
||||
|
||||
|
||||
class WaitlistResponse(BaseModel):
|
||||
ok: bool = True
|
||||
message: str = "You're on the list!"
|
||||
59
app/security.py
Normal file
59
app/security.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Security middleware stack.
|
||||
|
||||
1. RequestSizeLimiter — reject bodies > 4 KB (waitlist only needs ~100 bytes)
|
||||
2. OriginValidator — in production, reject requests without a valid Origin/Referer
|
||||
"""
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class RequestSizeLimiter(BaseHTTPMiddleware):
|
||||
"""Reject request bodies larger than max_bytes."""
|
||||
|
||||
MAX_BYTES = 4_096 # 4 KB — more than enough for a JSON email payload
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
content_length = request.headers.get("content-length")
|
||||
if content_length and int(content_length) > self.MAX_BYTES:
|
||||
return JSONResponse(
|
||||
status_code=413,
|
||||
content={"detail": "Request body too large."},
|
||||
)
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class OriginValidator(BaseHTTPMiddleware):
|
||||
"""
|
||||
In production, only allow requests whose Origin or Referer
|
||||
matches the allowed origins list. This mitigates CSRF/cross-origin abuse.
|
||||
|
||||
Skipped in development so local testing works without custom headers.
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
if settings.ENVIRONMENT != "production":
|
||||
return await call_next(request)
|
||||
|
||||
# Only check mutating methods
|
||||
if request.method not in ("POST", "PUT", "PATCH", "DELETE"):
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("origin") or ""
|
||||
referer = request.headers.get("referer") or ""
|
||||
|
||||
allowed = settings.origins_list
|
||||
origin_ok = any(origin.startswith(o) for o in allowed) if origin else False
|
||||
referer_ok = any(referer.startswith(o) for o in allowed) if referer else False
|
||||
|
||||
if not origin_ok and not referer_ok:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Forbidden."},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
Reference in New Issue
Block a user