Compare commits
3 Commits
a85f8fde29
...
feature/ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e668e3fd20 | ||
|
|
7ccdad431f | ||
|
|
4073863dc6 |
@@ -16,7 +16,7 @@ import re
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
# Alembic Config object (gives access to alembic.ini values).
|
||||
|
||||
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Add onboarding_completed_at column to users table.
|
||||
|
||||
Revision ID: c5d1e2f3a4b5
|
||||
Revises: b4c0d1e2f3a4
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c5d1e2f3a4b5"
|
||||
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "onboarding_completed_at")
|
||||
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""avatar_url_varchar_to_text
|
||||
|
||||
Revision ID: e04100e88ace
|
||||
Revises: c5d1e2f3a4b5
|
||||
Create Date: 2026-04-13 09:13:06.733674
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e04100e88ace'
|
||||
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.VARCHAR(length=2048),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.VARCHAR(length=2048),
|
||||
existing_nullable=True)
|
||||
@@ -65,17 +65,39 @@ async def get_current_user(
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
# Fetch name/surname/avatar_url from user row.
|
||||
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||
user_result = await db.execute(
|
||||
select(User.name, User.surname, User.avatar_url).where(User.id == user_id)
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
# Load decrypted core memory.
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass # Non-critical — return empty memory on failure
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
@@ -14,7 +14,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Auth routes: register, login, refresh, me, OAuth social login.
|
||||
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
@@ -12,6 +12,7 @@ OAuth (Google):
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
@@ -23,13 +24,15 @@ from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||
from app.config.settings import settings
|
||||
from app.core.llm import get_llm
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import OAuthAccount, RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
@@ -495,3 +498,298 @@ async def oauth_callback(
|
||||
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
|
||||
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||
|
||||
# We can't call the FastAPI dependency directly, but we can replicate
|
||||
# the core logic inline. Instead, we just re-query the same way.
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
user_result = await db.execute(
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
)
|
||||
|
||||
|
||||
# ── Onboarding routes ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateMemoryRequest(BaseModel):
|
||||
memory: dict[str, str] = Field(default_factory=dict)
|
||||
mark_onboarded: bool = False
|
||||
|
||||
|
||||
@router.put("/me/memory", response_model=UserProfile)
|
||||
async def update_memory(
|
||||
body: _UpdateMemoryRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||
mw = MemoryMiddleware(db)
|
||||
for key, value in body.memory.items():
|
||||
await mw.update_core(current_user.id, key, value)
|
||||
if body.mark_onboarded:
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
@router.post("/me/onboarding/reset")
|
||||
async def reset_onboarding(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Reset onboarding so the wizard runs again on next login."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = None
|
||||
await db.commit()
|
||||
return {"status": "reset"}
|
||||
|
||||
|
||||
class _NormalizeRequest(BaseModel):
|
||||
inputs: dict[str, str]
|
||||
|
||||
|
||||
class _NormalizeResponse(BaseModel):
|
||||
normalized: dict[str, str]
|
||||
|
||||
|
||||
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||
async def normalize_onboarding(
|
||||
body: _NormalizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _NormalizeResponse:
|
||||
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||
if not body.inputs:
|
||||
return _NormalizeResponse(normalized={})
|
||||
try:
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||
prompt = (
|
||||
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||
"Return a JSON object with the same keys and normalized values.\n"
|
||||
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||
f"Input: {json.dumps(body.inputs)}"
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
normalized = json.loads(response.content)
|
||||
return _NormalizeResponse(normalized=normalized)
|
||||
except Exception:
|
||||
# LLM failure must never block onboarding — return inputs unchanged
|
||||
return _NormalizeResponse(normalized=body.inputs)
|
||||
|
||||
|
||||
# ── Password management ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class _ChangePasswordRequest(BaseModel):
|
||||
current_password: str = Field(min_length=1)
|
||||
new_password: str = Field(min_length=8)
|
||||
|
||||
|
||||
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||
async def change_password(
|
||||
body: _ChangePasswordRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Change the authenticated user's password.
|
||||
|
||||
Requires the current password for verification.
|
||||
Returns 400 for social-only users (no password set).
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"This account uses social login and has no password to change",
|
||||
)
|
||||
|
||||
if not _verify_password(body.current_password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||
|
||||
user.password_hash = _hash_password(body.new_password)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── OAuth account management ─────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||
async def list_oauth_accounts(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict]:
|
||||
"""List all OAuth providers linked to the authenticated user."""
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"provider": a.provider,
|
||||
"provider_email": a.provider_email,
|
||||
"created_at": int(a.created_at.timestamp() * 1000),
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||
async def unlink_oauth_account(
|
||||
provider: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Unlink an OAuth provider from the authenticated user.
|
||||
|
||||
Refuses if the user has no password and this is their only login method.
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.user_id == current_user.id,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
account = oauth_result.scalar_one_or_none()
|
||||
if account is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||
|
||||
# Safety: don't let users lock themselves out.
|
||||
all_oauth = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
oauth_count = len(all_oauth.scalars().all())
|
||||
|
||||
if user.password_hash is None and oauth_count <= 1:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"Cannot unlink the only login method. Set a password first.",
|
||||
)
|
||||
|
||||
await db.delete(account)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Avatar update ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateAvatarRequest(BaseModel):
|
||||
avatar_url: str = Field(min_length=1)
|
||||
|
||||
|
||||
@router.put("/me/avatar", response_model=UserProfile)
|
||||
async def update_avatar(
|
||||
body: _UpdateAvatarRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's avatar URL.
|
||||
|
||||
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||
to its own storage and passes the resulting URL here.
|
||||
"""
|
||||
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.avatar_url = body.avatar_url
|
||||
await db.commit()
|
||||
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
# ── Account deletion ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||
async def delete_account(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Permanently delete the authenticated user's account.
|
||||
|
||||
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||
is cancelled if active.
|
||||
"""
|
||||
# Cancel Stripe subscription if present.
|
||||
try:
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
except HTTPException:
|
||||
pass # No subscription — that's fine
|
||||
|
||||
# Delete all memory rows (core, associative, episodic, proactive).
|
||||
try:
|
||||
from app.models import ( # noqa: PLC0415
|
||||
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||
)
|
||||
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||
await db.execute(
|
||||
model.__table__.delete().where(model.user_id == current_user.id)
|
||||
)
|
||||
except Exception:
|
||||
pass # Non-critical — cascade on User will handle most
|
||||
|
||||
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@@ -83,3 +83,16 @@ async def cancel_subscription(
|
||||
"""Cancel the active subscription."""
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/invoices", response_model=list[dict])
|
||||
async def list_invoices(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return billing history (invoices) from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured.
|
||||
"""
|
||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||
return invoices
|
||||
|
||||
@@ -200,6 +200,45 @@ class StripeService:
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
async def list_invoices(
|
||||
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return recent invoices for the user from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured or the user has
|
||||
no ``stripe_customer_id``.
|
||||
"""
|
||||
if not self._configured():
|
||||
return []
|
||||
|
||||
from app.models import User # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(User.stripe_customer_id).where(User.id == user_id)
|
||||
)
|
||||
customer_id = result.scalar_one_or_none()
|
||||
if not customer_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||
return [
|
||||
{
|
||||
"id": inv.id,
|
||||
"amount_due": inv.amount_due,
|
||||
"amount_paid": inv.amount_paid,
|
||||
"currency": inv.currency,
|
||||
"status": inv.status,
|
||||
"created": inv.created * 1000, # epoch ms
|
||||
"invoice_url": inv.hosted_invoice_url,
|
||||
"invoice_pdf": inv.invoice_pdf,
|
||||
}
|
||||
for inv in invoices.auto_paging_iter()
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
|
||||
@@ -57,7 +57,13 @@ class Settings(BaseSettings):
|
||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||
OAUTH_ENCRYPTION_KEY: str = ""
|
||||
|
||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||
CORS_ORIGINS: list[str] = [
|
||||
"app://.",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
"http://localhost:4173", # Vite preview (web SPA)
|
||||
"https://app.adiuvai.com", # Production web portal
|
||||
]
|
||||
|
||||
LANGFUSE_SECRET_KEY: str = ""
|
||||
LANGFUSE_PUBLIC_KEY: str = ""
|
||||
|
||||
@@ -30,7 +30,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -27,6 +27,34 @@ logger = logging.getLogger(__name__)
|
||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||
|
||||
# Mapping of core-memory language values to natural-language names for prompts.
|
||||
_LANGUAGE_NAMES: dict[str, str] = {
|
||||
"en": "English", "it": "Italian", "es": "Spanish",
|
||||
"fr": "French", "de": "German",
|
||||
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||
"spanish": "Spanish", "español": "Spanish",
|
||||
"french": "French", "français": "French",
|
||||
"german": "German", "deutsch": "German",
|
||||
}
|
||||
|
||||
|
||||
def _language_instruction(context: dict[str, Any]) -> str:
|
||||
"""Return a system-prompt suffix that tells the LLM to respond in the user's language.
|
||||
|
||||
Returns an empty string when the language is English or unknown — saves tokens.
|
||||
"""
|
||||
core = context.get("core_memory") or {}
|
||||
raw = (core.get("language") or "").strip().lower()
|
||||
if not raw:
|
||||
return ""
|
||||
lang = _LANGUAGE_NAMES.get(raw, raw.title()) # best-effort capitalisation
|
||||
if lang.lower() == "english":
|
||||
return ""
|
||||
return (
|
||||
f"\n\nIMPORTANT: Always respond in {lang}. "
|
||||
f"All your output text must be written in {lang}."
|
||||
)
|
||||
|
||||
_HOME_SYSTEM_PROMPT = (
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
@@ -876,6 +904,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
@@ -893,6 +922,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
@@ -916,6 +946,7 @@ async def run_home_stream(
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
text_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
@@ -948,6 +979,7 @@ async def run_floating_stream(
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
sanitizer = _FloatingStreamSanitizer()
|
||||
emitted_sanitized = False
|
||||
raw_chunks: list[str] = []
|
||||
|
||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -70,7 +70,7 @@ class User(Base):
|
||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
avatar_url: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||
@@ -79,6 +79,9 @@ class User(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
@@ -31,6 +31,15 @@ class UserProfile(BaseModel):
|
||||
surname: str | None = None
|
||||
tier: BillingTier
|
||||
avatar_url: str | None = None
|
||||
has_password: bool = True
|
||||
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||
|
||||
|
||||
class OAuthAccountInfo(BaseModel):
|
||||
provider: str
|
||||
provider_email: str | None = None
|
||||
created_at: int # epoch ms
|
||||
|
||||
|
||||
# ── Chat ─────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -28,7 +28,6 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.core.agent_runner import (
|
||||
_extract_items_from_content,
|
||||
@@ -597,7 +596,7 @@ async def test_run_cloud_agent_provider_fetch_error():
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||
from app.integrations import EmailMessage, encrypt_token
|
||||
from app.integrations import encrypt_token
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
|
||||
fernet_key = _Fernet.generate_key().decode()
|
||||
|
||||
@@ -40,7 +40,6 @@ from app.core.agent_runner import (
|
||||
_format_projects,
|
||||
_get_extraction_rules,
|
||||
_get_no_match_behavior,
|
||||
_is_overdue,
|
||||
run_local_agent,
|
||||
)
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
|
||||
@@ -21,7 +21,6 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@@ -18,13 +18,12 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import AgentRunLog
|
||||
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
|
||||
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
||||
@@ -40,11 +40,9 @@ Coverage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import (
|
||||
|
||||
@@ -7,10 +7,9 @@ column is stored as JSON in tests (SQLite-compatible).
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from __future__ import annotations
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.core.preprocessors import detect_content_type, preprocess
|
||||
|
||||
Reference in New Issue
Block a user