"""Auth routes: register, login, refresh, me. Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens tables). Passwords are hashed with bcrypt; refresh tokens are stored as SHA-256 hashes so plaintext never reaches the DB. """ from __future__ import annotations import hashlib import time import uuid from datetime import datetime, timedelta, timezone import bcrypt from fastapi import APIRouter, Depends, HTTPException, status from jose import jwt from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.config.settings import settings from app.db import get_session from app.models import RefreshToken, User from app.schemas import AuthTokens, UserProfile router = APIRouter(prefix="/auth", tags=["auth"]) # ── Internal helpers ───────────────────────────────────────────────── def _hash_password(password: str) -> str: return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() def _verify_password(password: str, hashed: str) -> bool: return bcrypt.checkpw(password.encode(), hashed.encode()) def _hash_token(plain_token: str) -> str: """SHA-256 of the plain refresh token string.""" return hashlib.sha256(plain_token.encode()).hexdigest() def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]: """Return (signed JWT, expires_at_ms).""" now = int(time.time()) exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 payload = { "sub": user_id, "email": email, "tier": tier, "exp": exp, "iat": now, } token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) return token, exp * 1000 # ms for client # ── Request bodies ──────────────────────────────────────────────────── class _RegisterRequest(BaseModel): email: str password: str class _LoginRequest(BaseModel): email: str password: str class _RefreshRequest(BaseModel): refresh_token: str # ── Routes 
def _as_utc(value: datetime) -> datetime:
    """Return *value* as a timezone-aware UTC datetime.

    Naive values are assumed to already be in UTC and are labelled as such.
    Aware values are converted, not relabelled, so a non-UTC offset coming
    back from the driver is not silently discarded.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


async def _issue_tokens(db: AsyncSession, user: User) -> AuthTokens:
    """Persist a new refresh token for *user*, commit, and return the token pair.

    The refresh token is a random URL-safe string. Only its SHA-256 hash is
    written to the DB, and the plaintext is handed back to the caller once.

    Args:
        db: Session for the current request. Any pending changes (a new user,
            a deleted refresh token) are committed together with the new token.
        user: The account the tokens are issued for.

    Returns:
        ``AuthTokens`` with the signed access JWT, the plaintext refresh token,
        and the access token's expiry in epoch milliseconds.
    """
    import secrets  # stdlib CSPRNG intended for bearer secrets

    plain_token = secrets.token_urlsafe(32)
    db.add(
        RefreshToken(
            user_id=user.id,
            token_hash=_hash_token(plain_token),
            expires_at=datetime.now(timezone.utc)
            + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
        )
    )
    # Read the user's attributes (via the JWT) before committing. If the
    # session expires attributes on commit, reading them afterwards would
    # require a lazy load, which an AsyncSession cannot perform implicitly.
    access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
    await db.commit()
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )


@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
    body: _RegisterRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Create a new account and return JWT tokens.

    Raises:
        HTTPException: 409 if the email is already registered, including when
            a concurrent registration wins the race and the insert is rejected
            by the database.
    """
    from fastapi.concurrency import run_in_threadpool
    from sqlalchemy.exc import IntegrityError

    existing = await db.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")

    # bcrypt is deliberately slow; run it off the event loop so other
    # requests are not stalled while it hashes.
    password_hash = await run_in_threadpool(_hash_password, body.password)
    user = User(
        id=str(uuid.uuid4()),
        email=body.email,
        password_hash=password_hash,
        tier="free",
    )
    db.add(user)
    try:
        # Flush now so a constraint violation surfaces here rather than
        # at commit time inside _issue_tokens.
        await db.flush()
    except IntegrityError:
        # Presumably the users.email unique constraint (another request
        # registered the same email after the check above) — TODO confirm
        # that is the only constraint this insert can violate.
        await db.rollback()
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") from None

    return await _issue_tokens(db, user)


@router.post("/login", response_model=AuthTokens)
async def login(
    body: _LoginRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Validate credentials and return JWT tokens.

    Raises:
        HTTPException: 401 if the email is unknown or the password is wrong.
            Both cases use the same message.
    """
    from fastapi.concurrency import run_in_threadpool

    result = await db.execute(select(User).where(User.email == body.email))
    user = result.scalar_one_or_none()
    # Password verification is CPU-bound bcrypt; keep it off the event loop.
    if user is None or not await run_in_threadpool(
        _verify_password, body.password, user.password_hash
    ):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")

    return await _issue_tokens(db, user)


@router.post("/refresh", response_model=AuthTokens)
async def refresh(
    body: _RefreshRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Rotate a refresh token and return a new token pair.

    The presented token is deleted and replaced in the same transaction, so
    each refresh token can be exchanged at most once.

    Raises:
        HTTPException: 401 if the token is unknown, expired, or already
            rotated, or if its user no longer exists.
    """
    token_hash = _hash_token(body.refresh_token)
    # FOR UPDATE row-locks the token. A concurrent request presenting the
    # same token waits here until this transaction commits. It then no
    # longer finds the deleted row, so the token cannot be rotated twice.
    result = await db.execute(
        select(RefreshToken)
        .where(RefreshToken.token_hash == token_hash)
        .with_for_update()
    )
    rt = result.scalar_one_or_none()
    if rt is None or _as_utc(rt.expires_at) < datetime.now(timezone.utc):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    # Rotate: delete old token, issue new one (committed together).
    await db.delete(rt)
    user_result = await db.execute(select(User).where(User.id == rt.user_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    return await _issue_tokens(db, user)


@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
    """Return the profile for the authenticated user.

    Authentication is performed entirely by the ``get_current_user``
    dependency; this handler just echoes its result.
    """
    return current_user