198 lines
6.2 KiB
Python
198 lines
6.2 KiB
Python
"""Auth routes: register, login, refresh, me.
|
|
|
|
Users and refresh tokens are persisted in PostgreSQL (``users`` and
``refresh_tokens`` tables). Passwords are hashed with bcrypt; refresh tokens
are stored only as SHA-256 hashes, so their plaintext never reaches the DB.
|
|
"""
|
|
|
|
from __future__ import annotations

import hashlib
import secrets
import time
import uuid
from datetime import datetime, timedelta, timezone

import bcrypt
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.concurrency import run_in_threadpool
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.config.settings import settings
from app.db import get_session
from app.models import RefreshToken, User
from app.schemas import AuthTokens, UserProfile
|
|
|
|
# Every route below is mounted under the "/auth" prefix and tagged "auth".
router = APIRouter(
    tags=["auth"],
    prefix="/auth",
)
|
|
|
|
|
|
# ── Internal helpers ─────────────────────────────────────────────────
|
|
|
|
|
|
def _hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    The salt produced by ``bcrypt.gensalt()`` is embedded in the returned
    string, which is what ``_verify_password`` later checks against.

    NOTE(review): bcrypt only considers the first 72 bytes of input, and
    recent bcrypt releases raise ValueError for longer passwords instead of
    truncating -- confirm the behaviour of the installed version.
    """
    salt = bcrypt.gensalt()
    hashed_bytes = bcrypt.hashpw(password.encode(), salt)
    return hashed_bytes.decode()
|
|
|
|
|
|
def _verify_password(password: str, hashed: str) -> bool:
    """Return True if *password* matches the stored bcrypt hash *hashed*.

    Args:
        password: Plaintext candidate password.
        hashed: Stored bcrypt hash string (as produced by ``_hash_password``).

    Returns:
        True on a match. False on a mismatch, and also when bcrypt rejects the
        inputs with ``ValueError``. For example, a malformed stored hash raises
        "Invalid salt". Treating that as a non-match turns a corrupt row into
        an ordinary failed login (401) instead of an unhandled 500.
    """
    try:
        return bcrypt.checkpw(password.encode(), hashed.encode())
    except ValueError:
        # Unverifiable input never authenticates; fail closed.
        return False
|
|
|
|
|
|
def _hash_token(plain_token: str) -> str:
|
|
"""SHA-256 of the plain refresh token string."""
|
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
|
|
|
|
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
    """Build and sign an access JWT for a user.

    Args:
        user_id: Stored as the ``sub`` claim.
        email: Stored as the ``email`` claim.
        tier: Stored as the ``tier`` claim.

    Returns:
        ``(token, expires_at_ms)``. ``expires_at_ms`` is the ``exp`` claim
        converted from epoch seconds to epoch milliseconds for the client.
    """
    issued_at = int(time.time())
    lifetime_seconds = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
    expires_at = issued_at + lifetime_seconds
    claims = {
        "sub": user_id,
        "email": email,
        "tier": tier,
        "exp": expires_at,
        "iat": issued_at,
    }
    signed = jwt.encode(claims, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
    return signed, expires_at * 1000
|
|
|
|
|
|
# ── Request bodies ────────────────────────────────────────────────────
|
|
|
|
|
|
class _RegisterRequest(BaseModel):
    # Request body for POST /auth/register.
    # NOTE(review): email is a plain str -- no format validation or case
    # normalisation happens here; register compares it verbatim to User.email.
    email: str
    password: str  # plaintext; hashed with bcrypt before it is stored
|
|
|
|
|
|
class _LoginRequest(BaseModel):
    # Request body for POST /auth/login.
    email: str  # matched verbatim against User.email
    password: str  # plaintext; checked against the stored bcrypt hash
|
|
|
|
|
|
class _RefreshRequest(BaseModel):
    # Request body for POST /auth/refresh.
    # Plaintext refresh token previously returned by register/login/refresh;
    # it is SHA-256 hashed before being looked up.
    refresh_token: str
|
|
|
|
|
|
# ── Routes ────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
    body: _RegisterRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Create a new account and return JWT tokens.

    Args:
        body: Email and plaintext password for the new account.
        db: Request-scoped async database session.

    Returns:
        A new access token, a newly stored refresh token, and the access-token
        expiry in milliseconds since the epoch.

    Raises:
        HTTPException(409): The email is already registered, including when a
            concurrent request registers it between our check and our insert.
    """
    existing = await db.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")

    # bcrypt is deliberately slow and CPU-bound. Hashing on the event loop
    # would stall every other request served by this worker.
    password_hash = await run_in_threadpool(_hash_password, body.password)

    user = User(
        id=str(uuid.uuid4()),
        email=body.email,
        password_hash=password_hash,
        tier="free",
    )
    db.add(user)
    try:
        # Emit the INSERT now so constraint violations surface before tokens
        # are issued.
        await db.flush()
    except IntegrityError:
        # The SELECT above is not atomic with this INSERT. A concurrent
        # registration of the same email lands here, presumably via a UNIQUE
        # constraint on users.email (TODO confirm against the schema).
        await db.rollback()
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") from None

    # The secrets module is the stdlib source intended for security tokens.
    plain_token = secrets.token_urlsafe(32)
    expires_at = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
    )
    rt = RefreshToken(
        user_id=user.id,
        token_hash=_hash_token(plain_token),  # only the digest is persisted
        expires_at=expires_at,
    )
    db.add(rt)
    await db.commit()

    access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )
|
|
|
|
|
|
@router.post("/login", response_model=AuthTokens)
async def login(
    body: _LoginRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Validate credentials and return JWT tokens.

    Args:
        body: Email and plaintext password to check.
        db: Request-scoped async database session.

    Returns:
        A new access token, a newly stored refresh token, and the access-token
        expiry in milliseconds since the epoch.

    Raises:
        HTTPException(401): Unknown email or wrong password. The response
            does not say which.
    """
    result = await db.execute(select(User).where(User.email == body.email))
    user = result.scalar_one_or_none()
    # bcrypt verification is intentionally slow and CPU-bound. Run it in a
    # worker thread so it does not block the event loop for other requests.
    # user.password_hash is read here, on the loop thread, before handing off.
    if user is None or not await run_in_threadpool(
        _verify_password, body.password, user.password_hash
    ):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")

    # Each login stores an additional refresh token. Tokens from earlier
    # logins are not revoked here.
    plain_token = secrets.token_urlsafe(32)
    expires_at = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
    )
    rt = RefreshToken(
        user_id=user.id,
        token_hash=_hash_token(plain_token),  # only the digest is persisted
        expires_at=expires_at,
    )
    db.add(rt)
    await db.commit()

    access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )
|
|
|
|
|
|
def _as_utc(value: datetime) -> datetime:
    """Return *value* as a UTC-aware datetime.

    Naive values are assumed to already be UTC, which matches how this
    module writes expiry timestamps. Aware values are converted rather than
    relabelled. Some drivers return timestamps in the session time zone, and
    those must not be shifted.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


@router.post("/refresh", response_model=AuthTokens)
async def refresh(
    body: _RefreshRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Rotate a refresh token and return a new token pair.

    The presented token is deleted and a new one is stored in the same
    commit, so each refresh token can be redeemed at most once.

    Args:
        body: The plaintext refresh token to redeem.
        db: Request-scoped async database session.

    Returns:
        A new access token, a new refresh token, and the access-token expiry
        in milliseconds since the epoch.

    Raises:
        HTTPException(401): Token unknown, already used, expired, or its user
            no longer exists.
    """
    token_hash = _hash_token(body.refresh_token)
    # Row-lock the token so concurrent refreshes with the same value are
    # serialised. On PostgreSQL the second request waits for the first to
    # commit its DELETE and then finds no row. Dialects without row locking
    # (e.g. SQLite) render no FOR UPDATE clause.
    result = await db.execute(
        select(RefreshToken)
        .where(RefreshToken.token_hash == token_hash)
        .with_for_update()
    )
    rt = result.scalar_one_or_none()

    now = datetime.now(timezone.utc)
    if rt is None or _as_utc(rt.expires_at) < now:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    # Capture what we need before the ORM object is marked deleted.
    user_id = rt.user_id

    # Rotate: delete the old token; the replacement is committed with it below.
    await db.delete(rt)

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    plain_token = secrets.token_urlsafe(32)
    new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    new_rt = RefreshToken(
        user_id=user.id,
        token_hash=_hash_token(plain_token),  # only the digest is persisted
        expires_at=new_expires,
    )
    db.add(new_rt)
    await db.commit()

    access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )
|
|
|
|
|
|
@router.get("/me", response_model=UserProfile)
async def me(
    current_user: UserProfile = Depends(get_current_user),
) -> UserProfile:
    """Echo back the profile that the ``get_current_user`` dependency resolved.

    All authentication work happens in the dependency. This handler just
    returns its result.
    """
    profile = current_user
    return profile
|