- Add services/auth/app/config.py with JWT_PRIVATE_KEY and JWT_PUBLIC_KEY (Auth Service local config - private key never leaves this service) - Update routes.py: sign tokens with RS256 private key - Update deps.py + verify.py: verify tokens with RS256 public key - Update shared/config.py: replace JWT_SECRET/JWT_ALGORITHM with JWT_PUBLIC_KEY (for optional local verification by other services) - Add sys.path fix in main.py for local dev without PYTHONPATH
250 lines
7.6 KiB
Python
250 lines
7.6 KiB
Python
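"""Auth routes: register, login, refresh, and the current-user profile (GET/PUT /me).

Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
Access tokens are signed here with the Auth Service's RS256 private key
(``app.config.auth_settings.JWT_PRIVATE_KEY``).
"""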
|
|
|
|
from __future__ import annotations

import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone

import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from shared.config import settings
from shared.db import get_session
from shared.models import RefreshToken, Subscription, User
from shared.schemas import AuthTokens, UserProfile

from app.config import auth_settings
from app.deps import get_current_user
|
|
|
|
# Every endpoint in this module is served under the "/auth" prefix and grouped
# under the "auth" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/auth", tags=["auth"])
|
|
|
|
|
|
# ── Internal helpers ─────────────────────────────────────────────────
|
|
|
|
|
|
def _hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    The password is UTF-8 encoded before hashing. The resulting hash bytes
    are decoded back to ``str`` so they can be stored in ``User.password_hash``.
    """
    salt = bcrypt.gensalt()
    hashed_bytes = bcrypt.hashpw(password.encode(), salt)
    return hashed_bytes.decode()
|
|
|
|
|
|
def _verify_password(password: str, hashed: str) -> bool:
|
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
|
|
|
|
|
def _hash_token(plain_token: str) -> str:
|
|
"""SHA-256 of the plain refresh token string."""
|
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
|
|
|
|
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
    """Build and sign an access token for the given user.

    The JWT carries ``sub``, ``email``, ``tier``, ``exp`` and ``iat`` claims.
    It expires ``settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES`` after issue and
    is signed with the Auth Service's RS256 private key.

    Returns:
        ``(token, expires_at_ms)``: the encoded JWT, and its ``exp`` claim
        converted from seconds to milliseconds for the client.
    """
    issued_at = int(time.time())
    lifetime_seconds = settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
    expires_at = issued_at + lifetime_seconds
    claims = dict(
        sub=user_id,
        email=email,
        tier=tier,
        exp=expires_at,
        iat=issued_at,
    )
    signed = jwt.encode(claims, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
    return signed, expires_at * 1000
|
|
|
|
|
|
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
    """Return the user's current tier as recorded in the subscriptions table.

    If there is no subscription row, or its tier is empty, the fallback is
    ``"power"`` when ``settings.ENV == "dev"`` and ``"free"`` otherwise.
    ``scalar_one_or_none`` raises if more than one subscription row matches.
    """
    query = select(Subscription.tier).where(Subscription.user_id == user_id)
    rows = await db.execute(query)
    fallback = "free" if settings.ENV != "dev" else "power"
    subscribed = rows.scalar_one_or_none()
    return subscribed if subscribed else fallback
|
|
|
|
|
|
# ── Request bodies ────────────────────────────────────────────────────
|
|
|
|
|
|
class _RegisterRequest(BaseModel):
    """Request body for ``POST /auth/register``."""

    email: str  # compared to and stored in User.email as given (no normalization here)
    password: str  # plaintext from the client; bcrypt-hashed before storage
    name: str | None = None  # optional; stored in User.name
    surname: str | None = None  # optional; stored in User.surname
|
|
|
|
|
|
class _LoginRequest(BaseModel):
    """Request body for ``POST /auth/login``."""

    email: str  # matched exactly against User.email
    password: str  # checked against the stored bcrypt hash
|
|
|
|
|
|
class _RefreshRequest(BaseModel):
    """Request body for ``POST /auth/refresh``."""

    # Plain refresh token previously returned by register/login/refresh;
    # looked up by its SHA-256 hash.
    refresh_token: str
|
|
|
|
|
|
class _UpdateProfileRequest(BaseModel):
    """Request body for ``PUT /auth/me``; fields left as ``None`` are not changed."""

    name: str | None = None
    surname: str | None = None
|
|
|
|
|
|
# ── Routes ────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
    body: _RegisterRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Create a new account and return JWT tokens.

    Creates the ``User`` row with a bcrypt-hashed password and a fresh Fernet
    key. It then stores the SHA-256 hash of a new refresh token and returns an
    RS256 access token together with the plain refresh token.

    Raises:
        HTTPException: 400 if the password is longer than bcrypt's 72-byte
            input limit; 409 if the email is already registered, including
            when a concurrent request inserts the same email first.
    """
    # bcrypt only consumes the first 72 bytes of its input. Older releases
    # silently truncate, so distinct long passwords would collide; newer ones
    # raise ValueError, which would surface as a 500. Reject explicitly.
    if len(body.password.encode()) > 72:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "Password must be at most 72 bytes"
        )

    existing = await db.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")

    user = User(
        id=str(uuid.uuid4()),
        email=body.email,
        name=body.name,
        surname=body.surname,
        password_hash=_hash_password(body.password),
        tier="free",
        encryption_key=Fernet.generate_key().decode(),
    )
    db.add(user)
    try:
        await db.flush()
    except IntegrityError:
        # A concurrent registration of the same email won the race between the
        # SELECT above and this INSERT (presumes a unique constraint on
        # User.email — TODO confirm). Report it like the ordinary duplicate.
        await db.rollback()
        raise HTTPException(
            status.HTTP_409_CONFLICT, "Email already registered"
        ) from None

    # Use the same tier source as /login and /refresh so the first access
    # token agrees with later ones; the stored "free" would ignore the dev
    # fallback applied by _get_live_tier.
    tier = await _get_live_tier(db, user.id)
    # Capture claim values before commit so they don't depend on the session's
    # expire_on_commit setting.
    user_id, email = user.id, user.email

    plain_token = str(uuid.uuid4())
    expires_at = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
    )
    db.add(
        RefreshToken(
            user_id=user_id,
            token_hash=_hash_token(plain_token),
            expires_at=expires_at,
        )
    )
    await db.commit()

    access_token, expires_at_ms = _make_access_token(user_id, email, tier)
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )
|
|
|
|
|
|
@router.post("/login", response_model=AuthTokens)
async def login(
    body: _LoginRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Authenticate by email and password, then issue a fresh token pair.

    Every successful login stores a new refresh token (as its SHA-256 hash).
    The access token's ``tier`` claim comes from ``_get_live_tier``.

    Raises:
        HTTPException: 401 "Invalid credentials" for an unknown email or a
            wrong password; the response is the same in both cases.
    """
    lookup = await db.execute(select(User).where(User.email == body.email))
    account = lookup.scalar_one_or_none()
    authenticated = account is not None and _verify_password(
        body.password, account.password_hash
    )
    if not authenticated:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")

    # Tier claim is resolved via _get_live_tier (subscriptions table + fallback),
    # not read from User.tier.
    live_tier = await _get_live_tier(db, account.id)

    refresh_plain = str(uuid.uuid4())
    refresh_expiry = datetime.now(timezone.utc) + timedelta(
        days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
    )
    db.add(
        RefreshToken(
            user_id=account.id,
            token_hash=_hash_token(refresh_plain),
            expires_at=refresh_expiry,
        )
    )
    await db.commit()

    access, access_expiry_ms = _make_access_token(
        account.id, account.email, live_tier
    )
    return AuthTokens(
        access_token=access,
        refresh_token=refresh_plain,
        expires_at=access_expiry_ms,
    )
|
|
|
|
|
|
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
    body: _RefreshRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Rotate a refresh token and return a new token pair.

    The presented token is deleted and a replacement is inserted in the same
    transaction, so each refresh token can be redeemed only once.

    Raises:
        HTTPException: 401 if the token is unknown, already used or expired,
            or if its user no longer exists.
    """
    token_hash = _hash_token(body.refresh_token)
    # Row lock (SELECT ... FOR UPDATE): on PostgreSQL/MySQL a concurrent refresh
    # presenting the same token waits until this transaction commits. It then
    # no longer finds the deleted row, so one token cannot be rotated twice.
    # SQLAlchemy omits the clause on SQLite.
    result = await db.execute(
        select(RefreshToken)
        .where(RefreshToken.token_hash == token_hash)
        .with_for_update()
    )
    rt = result.scalar_one_or_none()
    if rt is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    now = datetime.now(timezone.utc)
    token_expires = rt.expires_at
    if token_expires.tzinfo is None:
        # Some backends hand back naive datetimes. This module always writes
        # UTC values, so a naive value is interpreted as UTC. Aware values are
        # compared as-is, whatever their offset.
        token_expires = token_expires.replace(tzinfo=timezone.utc)
    if token_expires < now:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    owner_id = rt.user_id
    await db.delete(rt)

    user_result = await db.execute(select(User).where(User.id == owner_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    # Fetch live tier for the new JWT
    tier = await _get_live_tier(db, user.id)
    # Capture claim values before commit so they don't depend on the session's
    # expire_on_commit setting.
    user_id, email = user.id, user.email

    plain_token = str(uuid.uuid4())
    db.add(
        RefreshToken(
            user_id=user_id,
            token_hash=_hash_token(plain_token),
            expires_at=now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS),
        )
    )
    await db.commit()

    access_token, expires_at_ms = _make_access_token(user_id, email, tier)
    return AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )
|
|
|
|
|
|
@router.get("/me", response_model=UserProfile)
async def me(
    current_user: UserProfile = Depends(get_current_user),
) -> UserProfile:
    """Return the caller's profile exactly as resolved by ``get_current_user``.

    This handler does no database access of its own; the profile comes
    entirely from the dependency.
    """
    return current_user
|
|
|
|
|
|
@router.put("/me", response_model=UserProfile)
async def update_profile(
    body: _UpdateProfileRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> UserProfile:
    """Update the authenticated user's name and/or surname.

    Request fields left as ``None`` are not modified. The returned profile
    reflects the committed row and keeps the tier from ``current_user``.

    Raises:
        HTTPException: 401 "User not found" if the user's row no longer exists.
    """
    result = await db.execute(select(User).where(User.id == current_user.id))
    user = result.scalar_one_or_none()
    if user is None:
        # The row may have been removed after get_current_user resolved the
        # caller. Answer like /refresh does rather than letting scalar_one()'s
        # NoResultFound surface as a 500.
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    if body.name is not None:
        user.name = body.name
    if body.surname is not None:
        user.surname = body.surname

    await db.commit()
    # Reload so the response reflects the committed row.
    await db.refresh(user)

    return UserProfile(
        id=user.id,
        email=user.email,
        name=user.name,
        surname=user.surname,
        # Tier is not editable here; echo the caller's current tier.
        tier=current_user.tier,
    )
|