This commit is contained in:
2026-03-03 12:39:32 +01:00
parent 9787befd4a
commit 5d485b3665
12 changed files with 999 additions and 165 deletions

View File

@@ -1,33 +1,36 @@
"""Auth routes: register, login, refresh, me.
Users and refresh tokens are kept in an in-memory dict until Step 12
migrates them to PostgreSQL.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from typing import Any
from datetime import datetime, timedelta, timezone
import bcrypt
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.config.settings import settings
from app.db import get_session
from app.models import RefreshToken, User
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
_users: dict[str, dict[str, Any]] = {} # email → user record
_refresh_tokens: dict[str, str] = {} # plain token → user_id
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (signed JWT, expires_at_ms)."""
now = int(time.time())
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
access_payload = {
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": access_exp,
"exp": exp,
"iat": now,
}
access_token = jwt.encode(
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
)
refresh_token = str(uuid.uuid4())
_refresh_tokens[refresh_token] = user_id
return AuthTokens(
access_token=access_token,
refresh_token=refresh_token,
expires_at=access_exp * 1000, # milliseconds for client
)
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return token, exp * 1000 # ms for client
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(body: _RegisterRequest) -> AuthTokens:
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
if body.email in _users:
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user_id = str(uuid.uuid4())
_users[body.email] = {
"id": user_id,
"email": body.email,
"password_hash": _hash_password(body.password),
"tier": "free",
}
return _make_tokens(user_id, body.email, "free")
user = User(
id=str(uuid.uuid4()),
email=body.email,
password_hash=_hash_password(body.password),
tier="free",
)
db.add(user)
await db.flush() # get user.id without committing
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(body: _LoginRequest) -> AuthTokens:
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
user = _users.get(body.email)
if not user or not _verify_password(body.password, user["password_hash"]):
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
return _make_tokens(user["id"], user["email"], user["tier"])
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(body: _RefreshRequest) -> AuthTokens:
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
user_id = _refresh_tokens.pop(body.refresh_token, None)
if user_id is None:
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
user = next((u for u in _users.values() if u["id"] == user_id), None)
# Rotate: delete old token, issue new one.
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
return _make_tokens(user["id"], user["email"], user["tier"])
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.get("/me", response_model=UserProfile)

View File

@@ -11,9 +11,11 @@ from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.stripe_service import stripe_service
from app.db import get_session
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
@@ -44,6 +46,7 @@ async def create_checkout(
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
@@ -51,16 +54,17 @@ async def stripe_webhook(
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
stripe_service.handle_webhook(payload, stripe_signature)
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
@router.get("/subscription", response_model=dict)
async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = stripe_service.get_subscription(current_user.id)
sub = await stripe_service.get_subscription(current_user.id, db)
if sub is None:
return {
"tier": current_user.tier,
@@ -74,7 +78,8 @@ async def get_subscription(
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Cancel the active subscription."""
stripe_service.cancel_subscription(current_user.id)
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}