From 4073863dc6ebe50fb9de6c14a677d476a0cbf455 Mon Sep 17 00:00:00 2001 From: Roberto Musso Date: Sat, 11 Apr 2026 23:38:53 +0200 Subject: [PATCH] feat: add onboarding wizard backend - migration, schema, memory routes --- ...5d1e2f3a4b5_add_onboarding_completed_at.py | 31 +++++ app/api/middleware/auth.py | 24 +++- app/api/routes/auth.py | 130 +++++++++++++++++- app/models.py | 3 + app/schemas.py | 2 + 5 files changed, 186 insertions(+), 4 deletions(-) create mode 100644 alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py diff --git a/alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py b/alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py new file mode 100644 index 0000000..36d63bd --- /dev/null +++ b/alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py @@ -0,0 +1,31 @@ +"""Add onboarding_completed_at column to users table. + +Revision ID: c5d1e2f3a4b5 +Revises: b4c0d1e2f3a4 +Create Date: 2026-04-11 00:00:00.000000 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = "c5d1e2f3a4b5" +down_revision: Union[str, None] = "b4c0d1e2f3a4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("users", "onboarding_completed_at") diff --git a/app/api/middleware/auth.py b/app/api/middleware/auth.py index c1b302e..ccea249 100644 --- a/app/api/middleware/auth.py +++ b/app/api/middleware/auth.py @@ -65,12 +65,30 @@ async def get_current_user( default_tier = "power" if settings.ENV == "dev" else "free" tier: str = result.scalar_one_or_none() or default_tier - # Fetch name/surname/avatar_url from user row. + # Fetch name/surname/avatar_url/onboarding_completed_at from user row. user_result = await db.execute( - select(User.name, User.surname, User.avatar_url).where(User.id == user_id) + select( + User.name, User.surname, User.avatar_url, User.onboarding_completed_at, + ).where(User.id == user_id) ) user_row = user_result.one_or_none() + # Convert onboarding_completed_at to epoch ms (int) or None. + onboarding_ms: int | None = None + if user_row and user_row.onboarding_completed_at is not None: + onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000) + + # Load decrypted core memory. + from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415 + + memory_dict: dict[str, str] = {} + try: + mw = MemoryMiddleware(db) + blocks = await mw.list_core_blocks(user_id) + memory_dict = {b["label"]: b["value"] for b in blocks} + except Exception: + pass # Non-critical — return empty memory on failure + return UserProfile( id=user_id, email=email, @@ -78,4 +96,6 @@ async def get_current_user( surname=user_row.surname if user_row else None, avatar_url=user_row.avatar_url if user_row else None, tier=tier, + onboarding_completed_at=onboarding_ms, + memory=memory_dict, ) # type: ignore[arg-type] diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index 2e97295..65bdfd9 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -1,4 +1,4 @@ -"""Auth routes: register, login, refresh, me, OAuth social login. +"""Auth routes: register, login, refresh, me, OAuth social login, onboarding. Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens tables). Passwords are hashed with bcrypt; refresh tokens are stored as @@ -12,6 +12,7 @@ OAuth (Google): from __future__ import annotations import hashlib +import json import time import urllib.parse import uuid @@ -23,13 +24,15 @@ from cryptography.fernet import Fernet from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import RedirectResponse from jose import jwt -from pydantic import BaseModel +from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair from app.config.settings import settings +from app.core.llm import get_llm +from app.core.memory_middleware import MemoryMiddleware from app.db import get_session from app.models import OAuthAccount, RefreshToken, User from app.schemas import AuthTokens, UserProfile @@ -495,3 +498,126 @@ async def oauth_callback( plain_token, tokens = await _issue_refresh_token(new_user, db) await db.commit() return tokens + + +# ── Onboarding helpers ──────────────────────────────────────────────── + + +async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile: + """Re-fetch and return a full UserProfile (reuses get_current_user logic).""" + + # We can't call the FastAPI dependency directly, but we can replicate + # the core logic inline. Instead, we just re-query the same way. + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + default_tier = "power" if settings.ENV == "dev" else "free" + tier: str = result.scalar_one_or_none() or default_tier + + user_result = await db.execute( + select( + User.name, User.surname, User.avatar_url, User.onboarding_completed_at, + ).where(User.id == user_id) + ) + user_row = user_result.one_or_none() + + onboarding_ms: int | None = None + if user_row and user_row.onboarding_completed_at is not None: + onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000) + + memory_dict: dict[str, str] = {} + try: + mw = MemoryMiddleware(db) + blocks = await mw.list_core_blocks(user_id) + memory_dict = {b["label"]: b["value"] for b in blocks} + except Exception: + pass + + return UserProfile( + id=user_id, + email=email, + name=user_row.name if user_row else None, + surname=user_row.surname if user_row else None, + avatar_url=user_row.avatar_url if user_row else None, + tier=tier, + onboarding_completed_at=onboarding_ms, + memory=memory_dict, + ) + + +# ── Onboarding routes ──────────────────────────────────────────────── + + +class _UpdateMemoryRequest(BaseModel): + memory: dict[str, str] = Field(default_factory=dict) + mark_onboarded: bool = False + + +@router.put("/me/memory", response_model=UserProfile) +async def update_memory( + body: _UpdateMemoryRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> UserProfile: + """Update core memory key/value pairs and optionally mark onboarding complete.""" + mw = MemoryMiddleware(db) + for key, value in body.memory.items(): + await mw.update_core(current_user.id, key, value) + if body.mark_onboarded: + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + user.onboarding_completed_at = datetime.now(timezone.utc) + await db.commit() + return await _build_profile(current_user.id, current_user.email, db) + + +@router.post("/me/onboarding/reset") +async def reset_onboarding( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +): + """Reset onboarding so the wizard runs again on next login.""" + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + user.onboarding_completed_at = None + await db.commit() + return {"status": "reset"} + + +class _NormalizeRequest(BaseModel): + inputs: dict[str, str] + + +class _NormalizeResponse(BaseModel): + normalized: dict[str, str] + + +@router.post("/onboarding/normalize", response_model=_NormalizeResponse) +async def normalize_onboarding( + body: _NormalizeRequest, + current_user: UserProfile = Depends(get_current_user), +) -> _NormalizeResponse: + """One-shot LLM normalization for free-text onboarding answers.""" + if not body.inputs: + return _NormalizeResponse(normalized={}) + try: + llm = get_llm(model="gpt-4o-mini", temperature=0) + prompt = ( + "You normalize user onboarding answers into clean, ≤3-word canonical labels.\n" + "Return a JSON object with the same keys and normalized values.\n" + "Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n" + f"Input: {json.dumps(body.inputs)}" + ) + response = await llm.ainvoke( + [ + {"role": "system", "content": "You normalize user inputs. Return JSON only."}, + {"role": "user", "content": prompt}, + ], + ) + normalized = json.loads(response.content) + return _NormalizeResponse(normalized=normalized) + except Exception: + # LLM failure must never block onboarding — return inputs unchanged + return _NormalizeResponse(normalized=body.inputs) diff --git a/app/models.py b/app/models.py index 0795663..6a496b4 100644 --- a/app/models.py +++ b/app/models.py @@ -79,6 +79,9 @@ class User(Base): created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, server_default=func.now() ) + onboarding_completed_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True, default=None + ) updated_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() ) diff --git a/app/schemas.py b/app/schemas.py index bd08418..19afcae 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -31,6 +31,8 @@ class UserProfile(BaseModel): surname: str | None = None tier: BillingTier avatar_url: str | None = None + onboarding_completed_at: int | None = None # epoch ms, null = not onboarded + memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v # ── Chat ─────────────────────────────────────────────────────────────