Files
api/app/api/middleware/auth.py
2026-04-15 11:43:56 +02:00

104 lines
3.7 KiB
Python

"""Auth middleware — JWT validation dependency.
``get_current_user`` is the FastAPI dependency used by all protected routes.
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
from the ``subscriptions`` table so that tier changes take effect immediately
without requiring token re-issue.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.db import get_session
from app.schemas import UserProfile
# Security scheme injected via ``Depends`` into ``get_current_user`` to supply
# the raw token string. ``tokenUrl`` names the login endpoint, which is one of
# the routes that does not require a JWT.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncSession = Depends(get_session),
) -> UserProfile:
    """Validate a Bearer JWT and return the authenticated user.

    The JWT is used for identity and expiry only. The tier is fetched live
    from the ``subscriptions`` table so that upgrades/downgrades take effect
    immediately. When no subscription row exists the tier falls back to
    ``'free'``, or to ``'power'`` when ``settings.ENV == "dev"``.

    Profile fields come from the user row; when that row is missing they
    default to ``None`` (``has_password`` to ``False``). Core memory is
    loaded best-effort: a failure is logged and an empty dict is returned
    instead of failing the request.

    Args:
        token: Bearer token string supplied by ``oauth2_scheme``.
        db: Async database session supplied by ``get_session``.

    Returns:
        The ``UserProfile`` built from the token claims, the live tier, the
        user row and the core memory blocks.

    Raises:
        HTTPException: 401 when the token cannot be decoded (invalid or
            expired) or lacks a ``sub`` or ``email`` claim.
    """
    import logging  # noqa: PLC0415

    credentials_exc = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Keep the try body to the single call that can raise JWTError.
    try:
        payload = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError as exc:
        raise credentials_exc from exc

    user_id: str | None = payload.get("sub")
    email: str | None = payload.get("email")
    if not user_id or not email:
        raise credentials_exc

    # Live tier lookup — subscription row is the authoritative source.
    # In dev, fall back to 'power' (unlimited) so quota limits don't
    # block local development when no Stripe subscription exists.
    from app.models import Subscription, User  # noqa: PLC0415

    result = await db.execute(
        select(Subscription.tier).where(Subscription.user_id == user_id)
    )
    default_tier = "power" if settings.ENV == "dev" else "free"
    # NOTE(review): scalar_one_or_none() raises if several rows match —
    # presumably user_id is unique in subscriptions; confirm against the model.
    tier: str = result.scalar_one_or_none() or default_tier

    # Fetch name/surname/avatar_url/onboarding_completed_at/password_hash
    # from the user row.
    user_result = await db.execute(
        select(
            User.name,
            User.surname,
            User.avatar_url,
            User.onboarding_completed_at,
            User.password_hash,
        ).where(User.id == user_id)
    )
    user_row = user_result.one_or_none()

    # Convert onboarding_completed_at to epoch ms (int) or None.
    # NOTE(review): .timestamp() treats a naive datetime as local time —
    # confirm the column is timezone-aware.
    onboarding_ms: int | None = None
    if user_row and user_row.onboarding_completed_at is not None:
        onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)

    # Load decrypted core memory.
    from app.core.memory_middleware import MemoryMiddleware  # noqa: PLC0415

    memory_dict: dict[str, str] = {}
    try:
        mw = MemoryMiddleware(db)
        blocks = await mw.list_core_blocks(user_id)
        memory_dict = {b["label"]: b["value"] for b in blocks}
    except Exception:
        # Non-critical — return empty memory on failure, but leave a trace
        # so a broken memory store does not go unnoticed.
        logging.getLogger(__name__).warning(
            "Failed to load core memory for user %s", user_id, exc_info=True
        )

    return UserProfile(
        id=user_id,
        email=email,
        name=user_row.name if user_row else None,
        surname=user_row.surname if user_row else None,
        avatar_url=user_row.avatar_url if user_row else None,
        has_password=bool(user_row.password_hash) if user_row else False,
        tier=tier,
        onboarding_completed_at=onboarding_ms,
        memory=memory_dict,
    )  # type: ignore[arg-type]