From ce139bbac317fe04a98ff54f28da91a46534ee91 Mon Sep 17 00:00:00 2001 From: Roberto Musso Date: Fri, 10 Apr 2026 09:20:52 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20add=20OAuth=20DB=20schema=20=E2=80=94?= =?UTF-8?q?=20oauth=5Faccounts=20table,=20nullable=20password=5Fhash,=20av?= =?UTF-8?q?atar=5Furl=20on=20User?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Step 1 of Google login integration: Alembic migration for oauth_accounts + avatar_url on users, OAuthAccount model with User relationship, UserProfile schema extended with avatar_url, get_current_user updated to include avatar_url. Co-Authored-By: Claude Sonnet 4.6 --- .../b4c0d1e2f3a4_add_oauth_and_avatar.py | 56 +++++ app/api/middleware/auth.py | 5 +- app/api/routes/auth.py | 228 +++++++++++++++++- app/auth/__init__.py | 1 + app/auth/oauth_providers.py | 135 +++++++++++ app/config/settings.py | 8 + app/models.py | 25 +- app/schemas.py | 1 + 8 files changed, 454 insertions(+), 5 deletions(-) create mode 100644 alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py create mode 100644 app/auth/__init__.py create mode 100644 app/auth/oauth_providers.py diff --git a/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py b/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py new file mode 100644 index 0000000..8b9b34e --- /dev/null +++ b/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py @@ -0,0 +1,56 @@ +"""Add oauth_accounts table, nullable password_hash, avatar_url to users. + +Revision ID: b4c0d1e2f3a4 +Revises: a3b9c0d1e2f3 +Create Date: 2026-04-10 00:00:00.000000 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = "b4c0d1e2f3a4" +down_revision: Union[str, None] = "a3b9c0d1e2f3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── users: make password_hash nullable (social users have no password) ── + op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True) + + # ── users: add avatar_url ───────────────────────────────────────────── + op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True)) + + # ── oauth_accounts ──────────────────────────────────────────────────── + op.create_table( + "oauth_accounts", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("provider", sa.String(50), nullable=False), + sa.Column("provider_user_id", sa.String(255), nullable=False), + sa.Column("provider_email", sa.String(255), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("now()"), + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"), + ) + op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"]) + + +def downgrade() -> None: + op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts") + op.drop_table("oauth_accounts") + op.drop_column("users", "avatar_url") + op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False) diff --git a/app/api/middleware/auth.py b/app/api/middleware/auth.py index 4fcedf5..c1b302e 100644 --- a/app/api/middleware/auth.py +++ b/app/api/middleware/auth.py @@ -65,9 +65,9 @@ async def get_current_user( default_tier = "power" if settings.ENV == "dev" else "free" tier: str = result.scalar_one_or_none() or default_tier - # Fetch name/surname from user row. + # Fetch name/surname/avatar_url from user row. user_result = await db.execute( - select(User.name, User.surname).where(User.id == user_id) + select(User.name, User.surname, User.avatar_url).where(User.id == user_id) ) user_row = user_result.one_or_none() @@ -76,5 +76,6 @@ async def get_current_user( email=email, name=user_row.name if user_row else None, surname=user_row.surname if user_row else None, + avatar_url=user_row.avatar_url if user_row else None, tier=tier, ) # type: ignore[arg-type] diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index 1ab10ea..de900d4 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -1,8 +1,12 @@ -"""Auth routes: register, login, refresh, me. +"""Auth routes: register, login, refresh, me, OAuth social login. Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens tables). Passwords are hashed with bcrypt; refresh tokens are stored as SHA-256 hashes so plaintext never reaches the DB. + +OAuth (Google): + GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state + POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens """ from __future__ import annotations @@ -11,6 +15,7 @@ import hashlib import time import uuid from datetime import datetime, timedelta, timezone +from typing import Literal import bcrypt from cryptography.fernet import Fernet @@ -21,14 +26,38 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user +from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair from app.config.settings import settings from app.db import get_session -from app.models import RefreshToken, User +from app.models import OAuthAccount, RefreshToken, User from app.schemas import AuthTokens, UserProfile router = APIRouter(prefix="/auth", tags=["auth"]) +# ── OAuth provider registry ─────────────────────────────────────────── + +def _get_google_provider() -> GoogleOAuthProvider: + if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, + "Google login is not configured on this server", + ) + return GoogleOAuthProvider( + client_id=settings.GOOGLE_AUTH_CLIENT_ID, + client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET, + redirect_uri=settings.OAUTH_REDIRECT_URI, + ) + + +_PROVIDERS = {"google": _get_google_provider} + +# In-memory state store: state → (code_verifier, expires_at_epoch_s) +# Production note: replace with Redis for multi-process deployments. +_pending_states: dict[str, tuple[str, float]] = {} +_STATE_TTL_SECONDS = 600 # 10 minutes + + # ── Internal helpers ───────────────────────────────────────────────── @@ -231,5 +260,200 @@ async def update_profile( email=user.email, name=user.name, surname=user.surname, + avatar_url=user.avatar_url, tier=current_user.tier, ) + + +# ── OAuth helpers ───────────────────────────────────────────────────── + + +async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]: + """Create a refresh token row and return (plain_token, AuthTokens).""" + plain_token = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS + ) + rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=expires_at, + ) + db.add(rt) + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return plain_token, AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) + + +# ── OAuth request/response schemas ─────────────────────────────────── + + +class _OAuthAuthorizeResponse(BaseModel): + url: str + state: str + + +class _OAuthCallbackRequest(BaseModel): + code: str + state: str + + +# ── OAuth routes ────────────────────────────────────────────────────── + + +@router.get( + "/oauth/{provider}/authorize", + response_model=_OAuthAuthorizeResponse, + summary="Start OAuth flow — returns the provider consent-screen URL", +) +async def oauth_authorize( + provider: Literal["google"], +) -> _OAuthAuthorizeResponse: + """Generate a PKCE state + code_challenge and return the authorization URL. + + The client opens this URL in the system browser. After the user grants + consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback) + with ``code`` and ``state`` query params. The client then calls + ``POST /auth/oauth/{provider}/callback`` with those values. + """ + provider_factory = _PROVIDERS.get(provider) + if provider_factory is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}") + + oauth_provider = provider_factory() + state = str(uuid.uuid4()) + code_verifier, code_challenge = generate_pkce_pair() + + # Purge expired states to prevent unbounded growth. + now = time.time() + expired = [s for s, (_, exp) in _pending_states.items() if exp < now] + for s in expired: + del _pending_states[s] + + _pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS) + + url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge) + return _OAuthAuthorizeResponse(url=url, state=state) + + +@router.post( + "/oauth/{provider}/callback", + response_model=AuthTokens, + summary="Complete OAuth flow — exchange code and issue JWT tokens", +) +async def oauth_callback( + provider: Literal["google"], + body: _OAuthCallbackRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: + """Validate state, exchange the authorization code, and sign in (or register) the user. + + Resolution order: + 1. ``oauth_accounts`` row match → existing user, log in. + 2. Email match + ``email_verified=True`` → link OAuth account to existing user. + 3. No match → create new user (password_hash=None, avatar from provider). + """ + provider_factory = _PROVIDERS.get(provider) + if provider_factory is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}") + + # Validate state (CSRF protection). + now = time.time() + entry = _pending_states.pop(body.state, None) + if entry is None or entry[1] < now: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state") + + code_verifier, _ = entry + + oauth_provider = provider_factory() + + # Exchange code for tokens. + try: + token_data = await oauth_provider.exchange_code( + code=body.code, + code_verifier=code_verifier, + redirect_uri=settings.OAUTH_REDIRECT_URI, + ) + except Exception: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code" + ) + + access_token_google = token_data.get("access_token") + if not access_token_google: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response") + + # Fetch user identity. + try: + userinfo = await oauth_provider.get_userinfo(access_token_google) + except Exception: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider") + + # ── Resolution order ────────────────────────────────────────────── + + # 1. Existing OAuth link? + oauth_result = await db.execute( + select(OAuthAccount).where( + OAuthAccount.provider == provider, + OAuthAccount.provider_user_id == userinfo.provider_user_id, + ) + ) + oauth_account = oauth_result.scalar_one_or_none() + + if oauth_account is not None: + user_result = await db.execute(select(User).where(User.id == oauth_account.user_id)) + user = user_result.scalar_one() + # Backfill avatar if the user doesn't have one yet. + if user.avatar_url is None and userinfo.avatar_url: + user.avatar_url = userinfo.avatar_url + await db.commit() + plain_token, tokens = await _issue_refresh_token(user, db) + await db.commit() + return tokens + + # 2. Email match with a verified Google email → link accounts. + if userinfo.email_verified: + email_result = await db.execute(select(User).where(User.email == userinfo.email)) + existing_user = email_result.scalar_one_or_none() + + if existing_user is not None: + new_link = OAuthAccount( + user_id=existing_user.id, + provider=provider, + provider_user_id=userinfo.provider_user_id, + provider_email=userinfo.email, + ) + db.add(new_link) + if existing_user.avatar_url is None and userinfo.avatar_url: + existing_user.avatar_url = userinfo.avatar_url + plain_token, tokens = await _issue_refresh_token(existing_user, db) + await db.commit() + return tokens + + # 3. New user — social-only account (no password). + new_user = User( + id=str(uuid.uuid4()), + email=userinfo.email, + name=userinfo.name, + password_hash=None, + avatar_url=userinfo.avatar_url, + tier="free", + encryption_key=Fernet.generate_key().decode(), + ) + db.add(new_user) + await db.flush() # populate new_user.id + + new_oauth = OAuthAccount( + user_id=new_user.id, + provider=provider, + provider_user_id=userinfo.provider_user_id, + provider_email=userinfo.email, + ) + db.add(new_oauth) + + plain_token, tokens = await _issue_refresh_token(new_user, db) + await db.commit() + return tokens diff --git a/app/auth/__init__.py b/app/auth/__init__.py new file mode 100644 index 0000000..b45e86e --- /dev/null +++ b/app/auth/__init__.py @@ -0,0 +1 @@ +"OAuth provider abstractions and utilities." diff --git a/app/auth/oauth_providers.py b/app/auth/oauth_providers.py new file mode 100644 index 0000000..3363528 --- /dev/null +++ b/app/auth/oauth_providers.py @@ -0,0 +1,135 @@ +"""OAuth 2.0 + PKCE provider abstractions. + +Each provider implements a three-step flow designed for a desktop (public) client: + + 1. get_authorization_url(state, code_challenge) → str + Build the provider's consent-screen URL. State and code_challenge are + generated server-side; the client opens this URL in the system browser. + + 2. exchange_code(code, code_verifier, redirect_uri) → dict + Exchange the short-lived authorization code for an access token. + The code_verifier proves ownership of the PKCE challenge. + + 3. get_userinfo(access_token) → OAuthUserInfo + Fetch the canonical user identity from the provider. + +Currently supported providers: + - GoogleOAuthProvider (scope: openid email profile) + +Adding a new provider: + - Implement the three methods above. + - Register in _PROVIDERS inside routes/auth.py. +""" + +from __future__ import annotations + +import base64 +import hashlib +import os +import urllib.parse +from dataclasses import dataclass + +import httpx + + +# ── Data transfer objects ───────────────────────────────────────────── + + +@dataclass +class OAuthUserInfo: + """Normalized user identity returned by any provider.""" + + provider_user_id: str + email: str + email_verified: bool + avatar_url: str | None + name: str | None + + +# ── PKCE helpers ────────────────────────────────────────────────────── + + +def generate_pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair for PKCE S256. + + The code_verifier is a random 32-byte URL-safe base64 string. + The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding). + """ + code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode() + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return code_verifier, code_challenge + + +# ── Google provider ─────────────────────────────────────────────────── + + +class GoogleOAuthProvider: + """Google OAuth 2.0 provider (openid email profile scope). + + Uses Google's standard authorization endpoint with PKCE S256. + Does NOT use google-auth-oauthlib to keep the flow generic and async. + """ + + name = "google" + + _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" + _TOKEN_URL = "https://oauth2.googleapis.com/token" + _USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None: + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_authorization_url(self, state: str, code_challenge: str) -> str: + """Build the Google consent-screen URL.""" + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "response_type": "code", + "scope": "openid email profile", + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "select_account", + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + async def exchange_code( + self, code: str, code_verifier: str, redirect_uri: str + ) -> dict: + """Exchange authorization code for an access token.""" + async with httpx.AsyncClient() as client: + response = await client.post( + self._TOKEN_URL, + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "code_verifier": code_verifier, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + response.raise_for_status() + return response.json() + + async def get_userinfo(self, access_token: str) -> OAuthUserInfo: + """Fetch the authenticated user's identity from Google.""" + async with httpx.AsyncClient() as client: + response = await client.get( + self._USERINFO_URL, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + data = response.json() + + return OAuthUserInfo( + provider_user_id=data["sub"], + email=data["email"], + email_verified=data.get("email_verified", False), + avatar_url=data.get("picture"), + name=data.get("name"), + ) diff --git a/app/config/settings.py b/app/config/settings.py index f9eeabd..8e09de8 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -41,6 +41,14 @@ class Settings(BaseSettings): # MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts). MS_TENANT_ID: str = "common" + # Google Login OAuth credentials — scope: openid email profile. + # Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope). + GOOGLE_AUTH_CLIENT_ID: str = "" + GOOGLE_AUTH_CLIENT_SECRET: str = "" + # Deep-link URI registered in the Google Cloud Console for the desktop app. + # Must match the protocol registered in forge.config.ts. + OAUTH_REDIRECT_URI: str = "adiuvai://oauth/callback" + # Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth # tokens stored in cloud_agent_configs.oauth_token_encrypted. # Generate with: from cryptography.fernet import Fernet; Fernet.generate_key() diff --git a/app/models.py b/app/models.py index fea6054..0795663 100644 --- a/app/models.py +++ b/app/models.py @@ -69,7 +69,8 @@ class User(Base): email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) name: Mapped[str | None] = mapped_column(String(100), nullable=True) surname: Mapped[str | None] = mapped_column(String(100), nullable=True) - password_hash: Mapped[str] = mapped_column(String(255), nullable=False) + password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True) + avatar_url: Mapped[str | None] = mapped_column(String(2048), nullable=True) tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free") stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True) # Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration. @@ -88,6 +89,9 @@ class User(Base): subscription: Mapped[Subscription | None] = relationship( back_populates="user", uselist=False, cascade="all, delete-orphan" ) + oauth_accounts: Mapped[list[OAuthAccount]] = relationship( + back_populates="user", cascade="all, delete-orphan" + ) class RefreshToken(Base): @@ -108,6 +112,25 @@ class RefreshToken(Base): user: Mapped[User] = relationship(back_populates="refresh_tokens") +class OAuthAccount(Base): + __tablename__ = "oauth_accounts" + + id: Mapped[str] = mapped_column( + Uuid(as_uuid=False), primary_key=True, default=_uuid + ) + user_id: Mapped[str] = mapped_column( + Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True + ) + provider: Mapped[str] = mapped_column(String(50), nullable=False) + provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False) + provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) + + user: Mapped[User] = relationship(back_populates="oauth_accounts") + + class Subscription(Base): __tablename__ = "subscriptions" diff --git a/app/schemas.py b/app/schemas.py index 80996ba..bd08418 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -30,6 +30,7 @@ class UserProfile(BaseModel): name: str | None = None surname: str | None = None tier: BillingTier + avatar_url: str | None = None # ── Chat ─────────────────────────────────────────────────────────────