step 12
This commit is contained in:
@@ -1,33 +1,36 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
|
||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
||||
migrates them to PostgreSQL.
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.models import RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
@@ -36,30 +39,29 @@ def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
access_payload = {
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": access_exp,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
access_token = jwt.encode(
|
||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
refresh_token = str(uuid.uuid4())
|
||||
_refresh_tokens[refresh_token] = user_id
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=access_exp * 1000, # milliseconds for client
|
||||
)
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
@@ -76,40 +78,117 @@ class _RefreshRequest(BaseModel):
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
if body.email in _users:
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
user_id = str(uuid.uuid4())
|
||||
_users[body.email] = {
|
||||
"id": user_id,
|
||||
"email": body.email,
|
||||
"password_hash": _hash_password(body.password),
|
||||
"tier": "free",
|
||||
}
|
||||
return _make_tokens(user_id, body.email, "free")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(body: _LoginRequest) -> AuthTokens:
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
user = _users.get(body.email)
|
||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
||||
if user_id is None:
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
||||
|
||||
# Rotate: delete old token, issue new one.
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
|
||||
Reference in New Issue
Block a user