796 lines
28 KiB
Python
796 lines
28 KiB
Python
"""Auth routes: register, login, refresh, profile (me), OAuth social login,
onboarding, password and linked-account management, avatar, account deletion.

Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.

OAuth (Google):
    GET  /auth/oauth/{provider}/authorize    — returns consent-screen URL + state
    GET  /auth/oauth/{provider}/web-callback — provider redirect target; bounces
                                               to the adiuvai:// deep link
    POST /auth/oauth/{provider}/callback     — exchanges code, issues JWT tokens
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import time
|
|
import urllib.parse
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Literal
|
|
|
|
import bcrypt
|
|
from cryptography.fernet import Fernet
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from fastapi.responses import RedirectResponse
|
|
from jose import jwt
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
|
from app.config.settings import settings
|
|
from app.core.llm import get_llm
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.db import get_session
|
|
from app.models import OAuthAccount, RefreshToken, User
|
|
from app.schemas import AuthTokens, UserProfile
|
|
|
|
# Every route below is served under /auth and grouped under the "auth" OpenAPI tag.
router = APIRouter(tags=["auth"], prefix="/auth")
|
|
|
|
|
|
# ── OAuth provider registry ───────────────────────────────────────────
|
|
|
|
def _get_google_provider() -> GoogleOAuthProvider:
    """Build a Google OAuth client from the server settings.

    Returns:
        A GoogleOAuthProvider using the configured client ID/secret and
        ``settings.OAUTH_REDIRECT_URI``.

    Raises:
        HTTPException(503): either the client ID or the client secret is unset.
    """
    client_id = settings.GOOGLE_AUTH_CLIENT_ID
    client_secret = settings.GOOGLE_AUTH_CLIENT_SECRET
    if client_id and client_secret:
        return GoogleOAuthProvider(
            client_id=client_id,
            client_secret=client_secret,
            redirect_uri=settings.OAUTH_REDIRECT_URI,
        )
    raise HTTPException(
        status.HTTP_503_SERVICE_UNAVAILABLE,
        "Google login is not configured on this server",
    )
|
|
|
|
|
|
# Provider name → zero-argument factory that returns a configured provider client.
_PROVIDERS = dict(google=_get_google_provider)

# How long an OAuth ``state`` issued by /authorize stays valid.
_STATE_TTL_SECONDS = 10 * 60  # 10 minutes

# In-memory state store: state → (code_verifier, expires_at_epoch_s).
# Production note: replace with Redis for multi-process deployments.
_pending_states: dict[str, tuple[str, float]] = {}
|
|
|
|
|
|
# ── Internal helpers ─────────────────────────────────────────────────
|
|
|
|
|
|
def _hash_password(password: str) -> str:
    """Return a bcrypt hash of *password*, salted with a fresh salt, as text.

    NOTE(review): bcrypt is documented to use at most 72 bytes of input —
    confirm long passwords are handled as intended.
    """
    salt = bcrypt.gensalt()
    hashed_bytes = bcrypt.hashpw(password.encode(), salt)
    return hashed_bytes.decode()
|
|
|
|
|
|
def _verify_password(password: str, hashed: str) -> bool:
|
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
|
|
|
|
|
def _hash_token(plain_token: str) -> str:
|
|
"""SHA-256 of the plain refresh token string."""
|
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
|
|
|
|
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
    """Sign an access JWT carrying the user's id, email and tier.

    Returns:
        ``(token, expires_at_ms)``: the encoded JWT and its expiry as a Unix
        timestamp in milliseconds (the unit the client uses).
    """
    issued_at = int(time.time())
    expires_at = issued_at + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
    claims = {
        "sub": user_id,
        "email": email,
        "tier": tier,
        "exp": expires_at,
        "iat": issued_at,
    }
    signed = jwt.encode(claims, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
    # JWT "exp" is in seconds; the client expects milliseconds.
    return signed, expires_at * 1000
|
|
|
|
|
|
# ── Request bodies ────────────────────────────────────────────────────
|
|
|
|
|
|
class _RegisterRequest(BaseModel):
    """Body of POST /auth/register."""

    email: str
    # Same minimum as _ChangePasswordRequest.new_password. Without it an
    # account could be created with an empty password that /me/password would
    # never accept as a new one.
    password: str = Field(min_length=8)
    name: str | None = None
    surname: str | None = None
|
|
|
|
|
|
class _LoginRequest(BaseModel):
    """Body of POST /auth/login: email + password credentials."""

    email: str  # looked up verbatim against users.email
    password: str  # checked against the stored bcrypt hash
|
|
|
|
|
|
class _RefreshRequest(BaseModel):
    """Body of POST /auth/refresh."""

    refresh_token: str  # plain token previously returned by register/login/refresh
|
|
|
|
|
|
# ── Routes ────────────────────────────────────────────────────────────
|
|
|
|
|
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
    body: _RegisterRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Create a new password account and return a fresh JWT token pair.

    New accounts start on the "free" tier with their own Fernet encryption key.

    Raises:
        HTTPException(409): the email is already registered. This is reported
            either by the pre-check, or by the database when a concurrent
            request inserted the same email between the check and our insert.
    """
    from sqlalchemy.exc import IntegrityError  # noqa: PLC0415

    existing = await db.execute(select(User).where(User.email == body.email))
    if existing.scalar_one_or_none() is not None:
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")

    user = User(
        id=str(uuid.uuid4()),
        email=body.email,
        name=body.name,
        surname=body.surname,
        password_hash=_hash_password(body.password),
        tier="free",
        encryption_key=Fernet.generate_key().decode(),
    )
    db.add(user)
    try:
        # Flush now so a duplicate-email race surfaces here, not as a 500 at
        # commit. Presumably this trips the unique constraint on users.email.
        await db.flush()
    except IntegrityError as exc:
        await db.rollback()
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") from exc

    _, tokens = await _issue_refresh_token(user, db)
    await db.commit()
    return tokens
|
|
|
|
|
|
@router.post("/login", response_model=AuthTokens)
async def login(
    body: _LoginRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Validate email/password credentials and return a fresh JWT token pair.

    Raises:
        HTTPException(401): unknown email, wrong password, or a social-only
            account with no password. The message is identical in all three
            cases, so the response does not reveal which one applied.
    """
    result = await db.execute(select(User).where(User.email == body.email))
    user = result.scalar_one_or_none()
    # Social-only accounts are stored with password_hash=None. Reject them here
    # instead of letting bcrypt crash on None with a 500.
    if (
        user is None
        or user.password_hash is None
        or not _verify_password(body.password, user.password_hash)
    ):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")

    _, tokens = await _issue_refresh_token(user, db)
    await db.commit()
    return tokens
|
|
|
|
|
|
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
    body: _RefreshRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Rotate a refresh token and return a new token pair.

    The presented token's row is deleted and a new refresh token is issued, in
    the same commit.

    Raises:
        HTTPException(401): the token is unknown or expired, or its user no
            longer exists.
    """
    result = await db.execute(
        select(RefreshToken).where(RefreshToken.token_hash == _hash_token(body.refresh_token))
    )
    rt = result.scalar_one_or_none()
    if rt is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    expires_at = rt.expires_at
    if expires_at.tzinfo is None:
        # Naive values are taken to be UTC, which is how this module writes them.
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    # Aware values are compared as-is. Overwriting their tzinfo would shift the
    # instant whenever the driver returns a non-UTC offset.
    if expires_at < datetime.now(timezone.utc):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    # Rotate: delete old token, issue new one.
    await db.delete(rt)

    user_result = await db.execute(select(User).where(User.id == rt.user_id))
    user = user_result.scalar_one_or_none()
    if user is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")

    _, tokens = await _issue_refresh_token(user, db)
    await db.commit()
    return tokens
|
|
|
|
|
|
class _UpdateProfileRequest(BaseModel):
    """Body of PUT /auth/me. A field left as None is not changed."""

    name: str | None = None  # new first name, or None to keep the current one
    surname: str | None = None  # new surname, or None to keep the current one
|
|
|
|
|
|
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
    """Return the authenticated user's profile.

    The ``get_current_user`` dependency builds the profile; this handler
    just echoes it.
    """
    profile = current_user
    return profile
|
|
|
|
|
|
@router.put("/me", response_model=UserProfile)
async def update_profile(
    body: _UpdateProfileRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> UserProfile:
    """Update the authenticated user's name and/or surname.

    Only fields sent as non-None are changed. Returns the full profile, in the
    same shape as the other /me mutators (avatar, memory): has_password,
    onboarding timestamp and memory included. This keeps a client that caches
    the response from seeing those fields blanked.
    """
    result = await db.execute(select(User).where(User.id == current_user.id))
    user = result.scalar_one()

    if body.name is not None:
        user.name = body.name
    if body.surname is not None:
        user.surname = body.surname

    await db.commit()

    return await _build_profile(current_user.id, current_user.email, db)
|
|
|
|
|
|
# ── OAuth helpers ─────────────────────────────────────────────────────
|
|
|
|
|
|
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
    """Stage a new refresh-token row for *user* and build the matching token pair.

    The row is only added to the session; the caller must commit. Only the
    SHA-256 hash of the refresh token is stored.

    Returns:
        ``(plain_token, tokens)`` where ``tokens.refresh_token == plain_token``.
    """
    plain_token = str(uuid.uuid4())
    lifetime = timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
    db.add(
        RefreshToken(
            user_id=user.id,
            token_hash=_hash_token(plain_token),
            expires_at=datetime.now(timezone.utc) + lifetime,
        )
    )

    access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
    tokens = AuthTokens(
        access_token=access_token,
        refresh_token=plain_token,
        expires_at=expires_at_ms,
    )
    return plain_token, tokens
|
|
|
|
|
|
# ── OAuth request/response schemas ───────────────────────────────────
|
|
|
|
|
|
class _OAuthAuthorizeResponse(BaseModel):
    """Response of GET /auth/oauth/{provider}/authorize."""

    url: str  # provider consent-screen URL the client should open
    state: str  # opaque value the client must echo back to /callback
|
|
|
|
|
|
class _OAuthCallbackRequest(BaseModel):
    """Body of POST /auth/oauth/{provider}/callback."""

    code: str  # authorization code returned by the provider
    state: str  # state issued by /authorize for this flow
|
|
|
|
|
|
# ── OAuth routes ──────────────────────────────────────────────────────
|
|
|
|
|
|
@router.get(
    "/oauth/{provider}/web-callback",
    summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
    include_in_schema=False,
)
async def oauth_web_callback(
    provider: Literal["google"],
    code: str | None = None,
    state: str | None = None,
    error: str | None = None,
) -> RedirectResponse:
    """Google redirects here after the consent screen; bounce to the desktop deep link.

    On success the provider sends ``code`` and ``state``. If the user denies
    consent, the OAuth 2.0 error redirect carries ``error`` (e.g.
    ``access_denied``) and no ``code``. Previously that produced a 422 page in
    the browser; now both shapes are forwarded to ``adiuvai://oauth/callback``.
    Parameters that were not supplied are left out of the deep link.
    NOTE(review): confirm the Electron deep-link handler checks ``error``.

    No state validation happens here. ``POST /auth/oauth/{provider}/callback``
    validates the state before exchanging the code.

    Registered in Google Cloud Console as:
      http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
      https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
    """
    forwarded = {"code": code, "state": state, "provider": provider, "error": error}
    params = urllib.parse.urlencode({k: v for k, v in forwarded.items() if v is not None})
    deep_link = f"adiuvai://oauth/callback?{params}"
    return RedirectResponse(url=deep_link, status_code=302)
|
|
|
|
|
|
@router.get(
    "/oauth/{provider}/authorize",
    response_model=_OAuthAuthorizeResponse,
    summary="Start OAuth flow — returns the provider consent-screen URL",
)
async def oauth_authorize(
    provider: Literal["google"],
) -> _OAuthAuthorizeResponse:
    """Start an OAuth flow: mint a state + PKCE pair and return the consent URL.

    The code_verifier is stored in ``_pending_states`` under the new state for
    ``_STATE_TTL_SECONDS``. The client opens the returned URL in the system
    browser. After consent, the provider redirects to the configured
    ``OAUTH_REDIRECT_URI`` (presumably the /web-callback route, which bounces
    to adiuvai://oauth/callback) with ``code`` and ``state``. The client then
    calls ``POST /auth/oauth/{provider}/callback`` with those values.
    """
    factory = _PROVIDERS.get(provider)
    if factory is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")

    oauth_provider = factory()
    state = str(uuid.uuid4())
    code_verifier, code_challenge = generate_pkce_pair()

    now = time.time()
    # Purge expired states so abandoned flows don't grow the store forever.
    # The list is built first so the dict is not mutated while iterating.
    stale_states = [key for key, (_, expiry) in _pending_states.items() if expiry < now]
    for stale in stale_states:
        del _pending_states[stale]
    _pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)

    consent_url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
    return _OAuthAuthorizeResponse(url=consent_url, state=state)
|
|
|
|
|
|
async def _find_or_create_oauth_user(provider: str, userinfo, db: AsyncSession) -> User:
    """Resolve the local User for a provider identity, linking or creating as needed.

    ``userinfo`` is the object returned by the provider's ``get_userinfo``.
    This reads its provider_user_id, email, email_verified, name and
    avatar_url. New rows are only added to the session; the caller commits
    once, together with the refresh token row.

    Resolution order:
      1. An ``oauth_accounts`` row matches → that user (avatar backfilled).
      2. The email belongs to an existing user and the provider marks it
         verified → link this identity to that user (avatar backfilled).
      3. No user has the email → create a social-only user
         (password_hash=None) plus its OAuth link.

    Raises:
        HTTPException(400): no email was returned for a new identity.
        HTTPException(409): the email belongs to an existing account, but the
            provider did not mark it verified, so it cannot be auto-linked.
    """
    # 1. Existing OAuth link?
    oauth_result = await db.execute(
        select(OAuthAccount).where(
            OAuthAccount.provider == provider,
            OAuthAccount.provider_user_id == userinfo.provider_user_id,
        )
    )
    oauth_account = oauth_result.scalar_one_or_none()
    if oauth_account is not None:
        user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
        user = user_result.scalar_one()
        # Backfill avatar if the user doesn't have one yet.
        if user.avatar_url is None and userinfo.avatar_url:
            user.avatar_url = userinfo.avatar_url
        return user

    # Steps 2-3 key on the email. Without one we would look up / insert a NULL email.
    if not userinfo.email:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "No email address returned by provider"
        )

    # 2. Email already registered?
    email_result = await db.execute(select(User).where(User.email == userinfo.email))
    existing_user = email_result.scalar_one_or_none()
    if existing_user is not None:
        if not userinfo.email_verified:
            # Refuse with 409 instead of hitting the unique-email DB constraint.
            raise HTTPException(
                status.HTTP_409_CONFLICT,
                "An account with this email already exists. "
                "Please sign in with your password.",
            )
        # NOTE(review): /register does not verify email ownership. A password
        # account registered by someone else with this email therefore gains
        # this Google identity, and its password keeps working. Confirm this
        # pre-account-takeover risk is acceptable.
        db.add(
            OAuthAccount(
                user_id=existing_user.id,
                provider=provider,
                provider_user_id=userinfo.provider_user_id,
                provider_email=userinfo.email,
            )
        )
        if existing_user.avatar_url is None and userinfo.avatar_url:
            existing_user.avatar_url = userinfo.avatar_url
        return existing_user

    # 3. New user — social-only account (no password).
    new_user = User(
        id=str(uuid.uuid4()),
        email=userinfo.email,
        name=userinfo.name,
        password_hash=None,
        avatar_url=userinfo.avatar_url,
        tier="free",
        encryption_key=Fernet.generate_key().decode(),
    )
    db.add(new_user)
    await db.flush()  # insert the user row before the OAuth link that references it
    db.add(
        OAuthAccount(
            user_id=new_user.id,
            provider=provider,
            provider_user_id=userinfo.provider_user_id,
            provider_email=userinfo.email,
        )
    )
    return new_user


@router.post(
    "/oauth/{provider}/callback",
    response_model=AuthTokens,
    summary="Complete OAuth flow — exchange code and issue JWT tokens",
)
async def oauth_callback(
    provider: Literal["google"],
    body: _OAuthCallbackRequest,
    db: AsyncSession = Depends(get_session),
) -> AuthTokens:
    """Validate state, exchange the authorization code, and sign in (or register) the user.

    The state is single-use: it is removed from ``_pending_states`` whether or
    not the rest of the flow succeeds. User resolution (existing link → verified
    email link → new social-only user) is done by ``_find_or_create_oauth_user``.
    All resulting changes and the new refresh token are committed together.

    Raises:
        HTTPException(400): unknown provider, failed code exchange, missing
            access token, or failed userinfo fetch.
        HTTPException(401): unknown or expired state.
        HTTPException(409): unverified provider email already in use.
    """
    provider_factory = _PROVIDERS.get(provider)
    if provider_factory is None:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")

    # Validate state (CSRF protection).
    entry = _pending_states.pop(body.state, None)
    if entry is None or entry[1] < time.time():
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
    code_verifier, _ = entry

    oauth_provider = provider_factory()

    # Exchange code for tokens.
    try:
        token_data = await oauth_provider.exchange_code(
            code=body.code,
            code_verifier=code_verifier,
            redirect_uri=settings.OAUTH_REDIRECT_URI,
        )
    except Exception as exc:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
        ) from exc

    access_token_google = token_data.get("access_token")
    if not access_token_google:
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")

    # Fetch user identity.
    try:
        userinfo = await oauth_provider.get_userinfo(access_token_google)
    except Exception as exc:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider"
        ) from exc

    user = await _find_or_create_oauth_user(provider, userinfo, db)
    _, tokens = await _issue_refresh_token(user, db)
    await db.commit()
    return tokens
|
|
|
|
|
|
# ── Onboarding helpers ────────────────────────────────────────────────
|
|
|
|
|
|
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
    """Load everything a full UserProfile needs for *user_id* and return it.

    FastAPI dependencies can't be called directly, so the lookups are
    repeated here:

    * tier: the user's ``Subscription.tier``. With no subscription (or an
      empty tier) it falls back to "power" when ``settings.ENV == "dev"`` and
      "free" otherwise.
    * name, surname, avatar, onboarding timestamp (ms) and whether a password
      is set: read from the users row. All are None/False if the row is missing.
    * memory: core-memory blocks as ``{label: value}``, loaded best-effort.
      Any error leaves it empty.

    ``email`` is passed through unchanged.
    """
    from app.models import Subscription  # noqa: PLC0415

    sub_result = await db.execute(
        select(Subscription.tier).where(Subscription.user_id == user_id)
    )
    fallback_tier = "power" if settings.ENV == "dev" else "free"
    tier: str = sub_result.scalar_one_or_none() or fallback_tier

    row_result = await db.execute(
        select(
            User.name,
            User.surname,
            User.avatar_url,
            User.onboarding_completed_at,
            User.password_hash,
        ).where(User.id == user_id)
    )
    row = row_result.one_or_none()

    name = surname = avatar_url = None
    has_password = False
    onboarding_ms: int | None = None
    if row is not None:
        name, surname, avatar_url = row.name, row.surname, row.avatar_url
        has_password = bool(row.password_hash)
        completed_at = row.onboarding_completed_at
        if completed_at is not None:
            onboarding_ms = int(completed_at.timestamp() * 1000)

    memory: dict[str, str] = {}
    try:
        core_blocks = await MemoryMiddleware(db).list_core_blocks(user_id)
        memory = {block["label"]: block["value"] for block in core_blocks}
    except Exception:
        pass  # deliberate best-effort: memory stays empty on any failure

    return UserProfile(
        id=user_id,
        email=email,
        name=name,
        surname=surname,
        avatar_url=avatar_url,
        has_password=has_password,
        tier=tier,
        onboarding_completed_at=onboarding_ms,
        memory=memory,
    )
|
|
|
|
|
|
# ── Onboarding routes ────────────────────────────────────────────────
|
|
|
|
|
|
class _UpdateMemoryRequest(BaseModel):
    """Body of PUT /auth/me/memory."""

    # Core-memory key → value pairs to write; empty by default.
    memory: dict[str, str] = Field(default_factory=dict)
    # When True, also stamp users.onboarding_completed_at.
    mark_onboarded: bool = False
|
|
|
|
|
|
@router.put("/me/memory", response_model=UserProfile)
async def update_memory(
    body: _UpdateMemoryRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> UserProfile:
    """Write core-memory key/value pairs and optionally mark onboarding complete.

    Each ``body.memory`` entry is passed to ``MemoryMiddleware.update_core``.
    When ``mark_onboarded`` is set, ``users.onboarding_completed_at`` is
    stamped with the current UTC time. Returns the refreshed full profile.
    """
    mw = MemoryMiddleware(db)
    for key, value in body.memory.items():
        await mw.update_core(current_user.id, key, value)

    if body.mark_onboarded:
        result = await db.execute(select(User).where(User.id == current_user.id))
        user = result.scalar_one()
        user.onboarding_completed_at = datetime.now(timezone.utc)

    # Commit unconditionally. The memory writes share this session and must
    # persist even when mark_onboarded is False. If update_core already
    # commits internally, this commit is harmless.
    await db.commit()

    return await _build_profile(current_user.id, current_user.email, db)
|
|
|
|
|
|
@router.post("/me/onboarding/reset")
async def reset_onboarding(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
):
    """Clear the user's onboarding timestamp so the wizard runs again.

    Returns ``{"status": "reset"}``.
    """
    user = (
        await db.execute(select(User).where(User.id == current_user.id))
    ).scalar_one()
    user.onboarding_completed_at = None
    await db.commit()
    return {"status": "reset"}
|
|
|
|
|
|
class _NormalizeRequest(BaseModel):
    """Body of POST /auth/onboarding/normalize."""

    inputs: dict[str, str]  # question key → free-text answer to normalize
|
|
|
|
|
|
class _NormalizeResponse(BaseModel):
    """Response of POST /auth/onboarding/normalize."""

    normalized: dict[str, str]  # question key → normalized label
|
|
|
|
|
|
def _strip_json_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence (``` or ```json)."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag such as "json".
    _, _, rest = stripped.partition("\n")
    rest = rest.rstrip()
    if rest.endswith("```"):
        rest = rest[: -len("```")]
    return rest.strip()


@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
async def normalize_onboarding(
    body: _NormalizeRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> _NormalizeResponse:
    """One-shot LLM normalization for free-text onboarding answers.

    ``current_user`` is not read. The parameter is kept so the route runs the
    ``get_current_user`` dependency, which presumably enforces authentication.

    The response always has exactly the keys of ``body.inputs``:
    * Any failure (LLM error, non-string content, unparseable or non-object
      JSON) returns the inputs unchanged.
    * A key the model omitted, or answered with a non-string or blank value,
      keeps its original text.
    Normalization must never block onboarding.
    """
    if not body.inputs:
        return _NormalizeResponse(normalized={})
    try:
        llm = get_llm(model="gpt-4o-mini", temperature=0)
        prompt = (
            "You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
            "Return a JSON object with the same keys and normalized values.\n"
            "Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
            f"Input: {json.dumps(body.inputs)}"
        )
        response = await llm.ainvoke(
            [
                {"role": "system", "content": "You normalize user inputs. Return JSON only."},
                {"role": "user", "content": prompt},
            ],
        )
        # Models often wrap JSON in a ```json fence even when told not to.
        parsed = json.loads(_strip_json_fence(response.content))
    except Exception:
        # LLM failure must never block onboarding — return inputs unchanged
        return _NormalizeResponse(normalized=body.inputs)

    if not isinstance(parsed, dict):
        return _NormalizeResponse(normalized=body.inputs)

    normalized: dict[str, str] = {}
    for key, original in body.inputs.items():
        candidate = parsed.get(key)
        if isinstance(candidate, str) and candidate.strip():
            normalized[key] = candidate.strip()
        else:
            normalized[key] = original
    return _NormalizeResponse(normalized=normalized)
|
|
|
|
|
|
# ── Password management ───────────────────────────────────────────────
|
|
|
|
|
|
class _ChangePasswordRequest(BaseModel):
    """Body of PUT /auth/me/password."""

    # The user's existing password; must be non-empty.
    current_password: str = Field(min_length=1)
    # The replacement password; at least 8 characters.
    new_password: str = Field(min_length=8)
|
|
|
|
|
|
@router.put("/me/password", status_code=status.HTTP_200_OK)
async def change_password(
    body: _ChangePasswordRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
    """Replace the authenticated user's password after verifying the current one.

    Returns ``{"ok": True}``.

    Raises:
        HTTPException(400): the account has no password (social-only), or
            ``current_password`` does not match.
    """
    user = (
        await db.execute(select(User).where(User.id == current_user.id))
    ).scalar_one()

    stored_hash = user.password_hash
    if stored_hash is None:
        raise HTTPException(
            status.HTTP_400_BAD_REQUEST,
            "This account uses social login and has no password to change",
        )
    if not _verify_password(body.current_password, stored_hash):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")

    user.password_hash = _hash_password(body.new_password)
    await db.commit()
    return {"ok": True}
|
|
|
|
|
|
# ── OAuth account management ─────────────────────────────────────────
|
|
|
|
|
|
@router.get("/me/oauth-accounts", response_model=list[dict])
async def list_oauth_accounts(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> list[dict]:
    """List the OAuth providers linked to the authenticated user.

    Each entry has ``provider``, ``provider_email`` and ``created_at``
    (Unix time in milliseconds).
    """
    rows = await db.execute(
        select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
    )
    linked: list[dict] = []
    for account in rows.scalars().all():
        linked.append(
            {
                "provider": account.provider,
                "provider_email": account.provider_email,
                "created_at": int(account.created_at.timestamp() * 1000),
            }
        )
    return linked
|
|
|
|
|
|
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
async def unlink_oauth_account(
    provider: str,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
    """Unlink an OAuth provider from the authenticated user.

    Returns ``{"ok": True}``.

    Raises:
        HTTPException(404): no linked account for ``provider``.
        HTTPException(400): the user has no password and this is their only
            linked provider, so unlinking would leave no way to log in.
    """
    from sqlalchemy import func  # noqa: PLC0415

    result = await db.execute(select(User).where(User.id == current_user.id))
    user = result.scalar_one()

    oauth_result = await db.execute(
        select(OAuthAccount).where(
            OAuthAccount.user_id == current_user.id,
            OAuthAccount.provider == provider,
        )
    )
    account = oauth_result.scalar_one_or_none()
    if account is None:
        raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")

    # Safety: don't let users lock themselves out. The check only matters for
    # password-less users. The links are counted in SQL rather than loaded.
    if user.password_hash is None:
        count_result = await db.execute(
            select(func.count())
            .select_from(OAuthAccount)
            .where(OAuthAccount.user_id == current_user.id)
        )
        if count_result.scalar_one() <= 1:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                "Cannot unlink the only login method. Set a password first.",
            )

    await db.delete(account)
    await db.commit()
    return {"ok": True}
|
|
|
|
|
|
# ── Avatar update ─────────────────────────────────────────────────────
|
|
|
|
|
|
class _UpdateAvatarRequest(BaseModel):
    """Body of PUT /auth/me/avatar."""

    # New avatar URL; must be non-empty (the scheme is checked in the handler).
    avatar_url: str = Field(min_length=1)
|
|
|
|
|
|
@router.put("/me/avatar", response_model=UserProfile)
async def update_avatar(
    body: _UpdateAvatarRequest,
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> UserProfile:
    """Set the authenticated user's avatar URL and return the refreshed profile.

    Accepts ``{"avatar_url": "https://..."}``. The client uploads the image to
    its own storage and sends the resulting URL. Only ``https://``,
    ``http://`` and ``data:image/`` values are accepted.
    NOTE(review): data: URLs have no size cap here, so large inline images
    are stored as-is in users.avatar_url.

    Raises:
        HTTPException(400): the URL does not start with an accepted prefix.
    """
    accepted_prefixes = ("https://", "http://", "data:image/")
    if not body.avatar_url.startswith(accepted_prefixes):
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")

    user = (
        await db.execute(select(User).where(User.id == current_user.id))
    ).scalar_one()
    user.avatar_url = body.avatar_url
    await db.commit()

    return await _build_profile(current_user.id, current_user.email, db)
|
|
|
|
|
|
# ── Account deletion ─────────────────────────────────────────────────
|
|
|
|
|
|
@router.delete("/me", status_code=status.HTTP_200_OK)
async def delete_account(
    current_user: UserProfile = Depends(get_current_user),
    db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
    """Permanently delete the authenticated user's account.

    Steps:
    1. Cancel the Stripe subscription through ``stripe_service``. An
       HTTPException from it is treated as "no subscription" and ignored; any
       other error aborts the deletion.
    2. Best-effort delete of the user's memory rows (core, associative,
       episodic, proactive). This runs inside a SAVEPOINT, so a failure here
       rolls back only these statements and cannot leave the outer
       transaction unusable for step 3.
    3. Delete the User row. Refresh tokens, OAuth accounts and the
       subscription row are expected to go with it via relationship cascades.

    Returns ``{"ok": True}``.
    """
    # Cancel Stripe subscription if present.
    try:
        from app.billing.stripe_service import stripe_service  # noqa: PLC0415
        await stripe_service.cancel_subscription(current_user.id, db)
    except HTTPException:
        pass  # No subscription — that's fine

    # Delete all memory rows (core, associative, episodic, proactive).
    try:
        from app.models import (  # noqa: PLC0415
            MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
        )
        # Without the savepoint, a failed DELETE would abort the whole
        # PostgreSQL transaction, and the user delete below would fail too.
        async with db.begin_nested():
            for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
                await db.execute(
                    model.__table__.delete().where(model.user_id == current_user.id)
                )
    except Exception:
        pass  # Non-critical — cascade on User will handle most

    # Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
    result = await db.execute(select(User).where(User.id == current_user.id))
    user = result.scalar_one()
    await db.delete(user)
    await db.commit()

    return {"ok": True}
|