"""Auth routes: register, login, refresh, me.

Extracted from app/api/routes/auth.py. Models, schemas, the DB session and
general settings come from ``shared.*``; the JWT signing key
(``auth_settings``) and the ``get_current_user`` dependency come from this
service's own ``app.*`` package.
"""

from __future__ import annotations

import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone

import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from shared.config import settings
from shared.db import get_session
from shared.models import RefreshToken, Subscription, User
from shared.schemas import AuthTokens, UserProfile

from app.config import auth_settings
from app.deps import get_current_user

router = APIRouter(prefix="/auth", tags=["auth"])


# ── Internal helpers ─────────────────────────────────────────────────


def _hash_password(password: str) -> str:
    """Return the bcrypt hash of *password* (fresh random salt) as text."""
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()


def _verify_password(password: str, hashed: str | None) -> bool:
    """Return True if *password* matches the stored bcrypt hash *hashed*.

    Fails closed: a missing/empty stored hash, or one bcrypt cannot parse,
    yields False instead of raising, so a bad row produces the normal
    "invalid credentials" path rather than a server error.
    """
    if not hashed:
        return False
    try:
        return bcrypt.checkpw(password.encode(), hashed.encode())
    except ValueError:
        # bcrypt raises ValueError (e.g. "Invalid salt") for a malformed hash.
        return False


def _hash_token(plain_token: str) -> str:
    """Return the hex SHA-256 digest of a plain refresh-token string.

    Only this digest is stored, so a leaked table does not expose usable
    refresh tokens.
    """
    return hashlib.sha256(plain_token.encode()).hexdigest()


def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
    """Build an RS256-signed access JWT.

    Claims: ``sub`` (user id), ``email``, ``tier``, ``iat`` and ``exp``
    (``iat`` + ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES``), both in epoch
    seconds. Signed with ``auth_settings.JWT_PRIVATE_KEY``.

    Returns:
        ``(token, expires_at_ms)``; the expiry is in milliseconds for the
        client.
    """
    now = int(time.time())
    exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
    payload = {
        "sub": user_id,
        "email": email,
        "tier": tier,
        "exp": exp,
        "iat": now,
    }
    token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
    return token, exp * 1000  # ms for client


async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
    """Return the user's tier from the subscriptions table.

    Falls back to ``"power"`` when ``settings.ENV == "dev"`` and ``"free"``
    otherwise if there is no subscription row (or its tier is falsy).
    ``scalar_one_or_none`` raises if several rows match, so this presumes
    at most one subscription per user — confirm against the schema.
    """
    result = await db.execute(
        select(Subscription.tier).where(Subscription.user_id == user_id)
    )
    default_tier = "power" if settings.ENV == "dev" else "free"
    return result.scalar_one_or_none() or default_tier
# ── Request bodies ────────────────────────────────────────────────────


class _RegisterRequest(BaseModel):
    email: str
    password: str
    name: str | None = None
    surname: str | None = None


class _LoginRequest(BaseModel):
    email: str
    password: str


class _RefreshRequest(BaseModel):
    refresh_token: str


class _UpdateProfileRequest(BaseModel):
    name: str | None = None
    surname: str | None = None


# ── Route helpers ─────────────────────────────────────────────────────

# Lazily built bcrypt hash, verified against on logins for unknown emails so
# they cost the same bcrypt work as logins for known emails.
_dummy_password_hash: str | None = None


async def _hash_password_async(password: str) -> str:
    """Hash *password* with bcrypt in a worker thread.

    bcrypt is deliberately CPU-expensive, so calling it directly inside an
    async handler would stall every other request on the event loop.
    """
    import asyncio

    return await asyncio.to_thread(_hash_password, password)


async def _verify_password_async(password: str, hashed: str) -> bool:
    """Check *password* against *hashed* with bcrypt in a worker thread."""
    import asyncio

    return await asyncio.to_thread(_verify_password, password, hashed)


async def _get_dummy_password_hash() -> str:
    """Return the timing-equalisation hash, building it on first use.

    Concurrent first calls may each build one; that is harmless because any
    valid hash with the default cost serves the purpose.
    """
    global _dummy_password_hash
    if _dummy_password_hash is None:
        _dummy_password_hash = await _hash_password_async("timing-equalisation-dummy")
    return _dummy_password_hash


def _as_utc(moment: datetime) -> datetime:
    """Treat a naive datetime as UTC; return an aware one unchanged.

    Relabelling an aware value with ``replace(tzinfo=...)`` would silently
    shift it if it was not already in UTC.
    """
    if moment.tzinfo is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


async def _issue_tokens(
    db: AsyncSession, user_id: str, email: str, tier: str
) -> AuthTokens:
    """Store a new refresh token, commit the session, and return a token pair.

    The commit also persists anything the caller already staged on *db*
    (a new user, a deleted old refresh token, ...). *user_id* and *email*
    are plain values captured by the caller before the commit, so nothing
    here reads ORM attributes after commit.
    """
    plain_token = str(uuid.uuid4())
    expires_at = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
    )
    db.add(
        RefreshToken(
            user_id=user_id,
            token_hash=_hash_token(plain_token),
            expires_at=expires_at,
        )
    )
    await db.commit()

    access_token, expires_at_ms = _make_access_token(user_id, email, tier)
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )


# ── Routes ────────────────────────────────────────────────────────────


@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
    body: _RegisterRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Create a new account and return JWT tokens.

    Raises:
        HTTPException(409): the email is already registered, including when a
            concurrent request inserts the same email between the existence
            check and our insert.
    """
    from sqlalchemy.exc import IntegrityError

    existing = await db.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")

    user = User(
        id=str(uuid.uuid4()),
        email=body.email,
        name=body.name,
        surname=body.surname,
        password_hash=await _hash_password_async(body.password),
        tier="free",
        encryption_key=Fernet.generate_key().decode(),
    )
    db.add(user)
    try:
        await db.flush()
    except IntegrityError:
        # Lost a race with a concurrent registration? Only report 409 if the
        # email really exists now; any other constraint failure is re-raised.
        await db.rollback()
        taken = await db.execute(select(User.id).where(User.email == body.email))
        if taken.first() is not None:
            raise HTTPException(
                status.HTTP_409_CONFLICT, "Email already registered"
            ) from None
        raise

    user_id, email = user.id, user.email
    # Same tier source as login/refresh, so the claim does not change on the
    # first refresh (e.g. the dev-environment default).
    tier = await _get_live_tier(db, user_id)
    return await _issue_tokens(db, user_id, email, tier)


@router.post("/login", response_model=AuthTokens)
async def login(
    body: _LoginRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Validate credentials and return JWT tokens.

    Raises:
        HTTPException(401): unknown email or wrong password. Both cases get
            the same message and comparable bcrypt cost, so neither the body
            nor the response time reveals whether the email is registered.
    """
    result = await db.execute(select(User).where(User.email == body.email))
    user = result.scalar_one_or_none()
    if user is None:
        # Spend the same bcrypt work as a real check; the result is ignored.
        await _verify_password_async(body.password, await _get_dummy_password_hash())
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
    if not await _verify_password_async(body.password, user.password_hash):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")

    user_id, email = user.id, user.email
    # Fetch live tier for the JWT claim
    tier = await _get_live_tier(db, user_id)
    return await _issue_tokens(db, user_id, email, tier)


@router.post("/refresh", response_model=AuthTokens)
async def refresh(
    body: _RefreshRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Rotate a refresh token and return a new token pair.

    The presented token is deleted and its replacement is stored in the same
    commit.

    Raises:
        HTTPException(401): the token is unknown or expired, or its user no
            longer exists.
    """
    result = await db.execute(
        select(RefreshToken)
        .where(RefreshToken.token_hash == _hash_token(body.refresh_token))
        # Lock the row so two concurrent refreshes with the same token cannot
        # both rotate it. Presumably on PostgreSQL (READ COMMITTED) the
        # second waits and then finds no row; SQLite ignores FOR UPDATE.
        .with_for_update()
    )
    rt = result.scalar_one_or_none()
    if rt is None or _as_utc(rt.expires_at) < datetime.now(timezone.utc):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    owner_id = rt.user_id
    await db.delete(rt)

    user_result = await db.execute(select(User).where(User.id == owner_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    user_id, email = user.id, user.email
    # Fetch live tier for the new JWT
    tier = await _get_live_tier(db, user_id)
    return await _issue_tokens(db, user_id, email, tier)


@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
    """Return the profile resolved by ``get_current_user`` for this request."""
    return current_user


@router.put("/me", response_model=UserProfile)
async def update_profile(
    body: _UpdateProfileRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> UserProfile:
    """Update the authenticated user's name and/or surname.

    Fields left as ``None`` in the request are not changed. The returned
    ``tier`` is taken from ``current_user`` and is not re-read.

    Raises:
        HTTPException(401): the user row no longer exists (for example, the
            account was removed after the access token was issued).
    """
    result = await db.execute(select(User).where(User.id == current_user.id))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    if body.name is not None:
        user.name = body.name
    if body.surname is not None:
        user.surname = body.surname
    await db.commit()
    await db.refresh(user)

    return UserProfile(
        id=user.id,
        email=user.email,
        name=user.name,
        surname=user.surname,
        tier=current_user.tier,
    )