Compare commits
9 Commits
7253f6fe72
...
feature/ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e668e3fd20 | ||
|
|
7ccdad431f | ||
|
|
4073863dc6 | ||
|
|
a85f8fde29 | ||
|
|
90500a3462 | ||
|
|
c1a8ac7669 | ||
|
|
c510cbaae5 | ||
|
|
ce139bbac3 | ||
|
|
3cf067faea |
35
.env.example
35
.env.example
@@ -13,10 +13,45 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||
#
|
||||
# API keys — only the key(s) matching your chosen provider(s) are required.
|
||||
# The correct key is picked automatically from the model prefix (e.g.
|
||||
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GOOGLE_API_KEY=
|
||||
CEREBRAS_API_KEY=
|
||||
|
||||
# Default model used by any agent that does not have a specific override below.
|
||||
LLM_MODEL=gpt-5-mini
|
||||
LLM_EMBED_MODEL=text-embedding-3-small
|
||||
|
||||
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
||||
# In Docker, point this to a named-volume path so tokens survive restarts.
|
||||
# GITHUB_COPILOT_TOKEN_DIR=
|
||||
|
||||
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
||||
# Leave a value empty to fall back to LLM_MODEL.
|
||||
# Each agent resolves its API key from the model prefix automatically.
|
||||
#
|
||||
# Intent classifier — routes user messages to the right domain agent.
|
||||
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
||||
LLM_MODEL_CLASSIFIER=
|
||||
|
||||
# Home-agent — handles chat from the home screen (all tools available).
|
||||
LLM_MODEL_HOME_AGENT=
|
||||
|
||||
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
||||
LLM_MODEL_FLOATING_AGENT=
|
||||
|
||||
# Unified-processor — processes local directory files (local agent runner).
|
||||
LLM_MODEL_UNIFIED_PROCESSOR=
|
||||
|
||||
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||
LLM_MODEL_CLOUD_PROCESSOR=
|
||||
|
||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||
LLM_MODEL_SETUP_AGENT=
|
||||
|
||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=
|
||||
|
||||
@@ -16,7 +16,7 @@ import re
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
# Alembic Config object (gives access to alembic.ini values).
|
||||
|
||||
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
||||
|
||||
Revision ID: b4c0d1e2f3a4
|
||||
Revises: a3b9c0d1e2f3
|
||||
Create Date: 2026-04-10 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b4c0d1e2f3a4"
|
||||
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── users: make password_hash nullable (social users have no password) ──
|
||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
||||
|
||||
# ── users: add avatar_url ─────────────────────────────────────────────
|
||||
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
||||
|
||||
# ── oauth_accounts ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("provider", sa.String(50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(255), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||
)
|
||||
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_column("users", "avatar_url")
|
||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
||||
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""Add onboarding_completed_at column to users table.
|
||||
|
||||
Revision ID: c5d1e2f3a4b5
|
||||
Revises: b4c0d1e2f3a4
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c5d1e2f3a4b5"
|
||||
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "onboarding_completed_at")
|
||||
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""avatar_url_varchar_to_text
|
||||
|
||||
Revision ID: e04100e88ace
|
||||
Revises: c5d1e2f3a4b5
|
||||
Create Date: 2026-04-13 09:13:06.733674
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e04100e88ace'
|
||||
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.VARCHAR(length=2048),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.VARCHAR(length=2048),
|
||||
existing_nullable=True)
|
||||
@@ -65,16 +65,39 @@ async def get_current_user(
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
# Fetch name/surname from user row.
|
||||
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||
user_result = await db.execute(
|
||||
select(User.name, User.surname).where(User.id == user_id)
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
# Load decrypted core memory.
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass # Non-critical — return empty memory on failure
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
@@ -32,9 +32,8 @@ from typing import Any
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.agents.filesystem_agent import make_directory_tools
|
||||
from app.config.settings import settings
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.schemas import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -257,15 +256,17 @@ async def _call_llm_with_tools(
|
||||
else:
|
||||
messages.append(AIMessage(content=turn["content"]))
|
||||
|
||||
llm = get_llm(model=None, temperature=0.4)
|
||||
llm = get_agent_llm("setup", temperature=0.4)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||
_lf_ctx.__enter__()
|
||||
|
||||
_span_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="journey-setup",
|
||||
metadata={"user_id": user_id or None, "session_id": session_id or None},
|
||||
input=history[-1]["content"] if history else "",
|
||||
)
|
||||
if lf else None
|
||||
@@ -278,7 +279,7 @@ async def _call_llm_with_tools(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="journey-setup-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent("setup"),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
@@ -287,7 +288,7 @@ async def _call_llm_with_tools(
|
||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
if _gen_ctx:
|
||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
_gen_ctx.__exit__(None, None, None)
|
||||
|
||||
resp_text = _as_text(response.content)
|
||||
@@ -343,6 +344,7 @@ async def _call_llm_with_tools(
|
||||
finally:
|
||||
if _span_ctx:
|
||||
_span_ctx.__exit__(None, None, None)
|
||||
_lf_ctx.__exit__(None, None, None)
|
||||
if lf:
|
||||
lf.flush()
|
||||
|
||||
|
||||
@@ -12,8 +12,11 @@ in backend agent-config tables.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
@@ -177,6 +180,11 @@ async def trigger_agent_run(
|
||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||
|
||||
last_run_dt = (
|
||||
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||
if body.last_run_at
|
||||
else None
|
||||
)
|
||||
config = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
@@ -184,10 +192,12 @@ async def trigger_agent_run(
|
||||
name="Local Directory Monitor",
|
||||
directory_paths=[body.directory],
|
||||
data_types=_to_data_types(body.what_to_extract),
|
||||
prompt_template=body.custom_agent_prompt,
|
||||
prompt_template=body.custom_agent_prompt or "",
|
||||
agent_config=body.agent_config,
|
||||
file_extensions=[],
|
||||
schedule_cron=body.batch_interval,
|
||||
enabled=True,
|
||||
last_run_at=last_run_dt,
|
||||
)
|
||||
|
||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||
|
||||
@@ -1,34 +1,68 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
|
||||
OAuth (Google):
|
||||
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
||||
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||
from app.config.settings import settings
|
||||
from app.core.llm import get_llm
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import RefreshToken, User
|
||||
from app.models import OAuthAccount, RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── OAuth provider registry ───────────────────────────────────────────
|
||||
|
||||
def _get_google_provider() -> GoogleOAuthProvider:
|
||||
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
"Google login is not configured on this server",
|
||||
)
|
||||
return GoogleOAuthProvider(
|
||||
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
||||
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||
)
|
||||
|
||||
|
||||
_PROVIDERS = {"google": _get_google_provider}
|
||||
|
||||
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
||||
# Production note: replace with Redis for multi-process deployments.
|
||||
_pending_states: dict[str, tuple[str, float]] = {}
|
||||
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -231,5 +265,531 @@ async def update_profile(
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
surname=user.surname,
|
||||
avatar_url=user.avatar_url,
|
||||
tier=current_user.tier,
|
||||
)
|
||||
|
||||
|
||||
# ── OAuth helpers ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
||||
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return plain_token, AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
# ── OAuth request/response schemas ───────────────────────────────────
|
||||
|
||||
|
||||
class _OAuthAuthorizeResponse(BaseModel):
|
||||
url: str
|
||||
state: str
|
||||
|
||||
|
||||
class _OAuthCallbackRequest(BaseModel):
|
||||
code: str
|
||||
state: str
|
||||
|
||||
|
||||
# ── OAuth routes ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/oauth/{provider}/web-callback",
|
||||
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def oauth_web_callback(
|
||||
provider: Literal["google"],
|
||||
code: str,
|
||||
state: str,
|
||||
) -> RedirectResponse:
|
||||
"""Google redirects here after user consent.
|
||||
|
||||
This endpoint immediately redirects to the Electron deep-link URI so the
|
||||
desktop app receives the authorization code. It is intentionally simple —
|
||||
no state validation here (the Electron app + backend callback do that).
|
||||
|
||||
Registered in Google Cloud Console as:
|
||||
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
||||
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
||||
"""
|
||||
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
||||
deep_link = f"adiuvai://oauth/callback?{params}"
|
||||
return RedirectResponse(url=deep_link, status_code=302)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/oauth/{provider}/authorize",
|
||||
response_model=_OAuthAuthorizeResponse,
|
||||
summary="Start OAuth flow — returns the provider consent-screen URL",
|
||||
)
|
||||
async def oauth_authorize(
|
||||
provider: Literal["google"],
|
||||
) -> _OAuthAuthorizeResponse:
|
||||
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
||||
|
||||
The client opens this URL in the system browser. After the user grants
|
||||
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
||||
with ``code`` and ``state`` query params. The client then calls
|
||||
``POST /auth/oauth/{provider}/callback`` with those values.
|
||||
"""
|
||||
provider_factory = _PROVIDERS.get(provider)
|
||||
if provider_factory is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||
|
||||
oauth_provider = provider_factory()
|
||||
state = str(uuid.uuid4())
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
|
||||
# Purge expired states to prevent unbounded growth.
|
||||
now = time.time()
|
||||
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
||||
for s in expired:
|
||||
del _pending_states[s]
|
||||
|
||||
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
||||
|
||||
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
||||
return _OAuthAuthorizeResponse(url=url, state=state)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/{provider}/callback",
|
||||
response_model=AuthTokens,
|
||||
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
||||
)
|
||||
async def oauth_callback(
|
||||
provider: Literal["google"],
|
||||
body: _OAuthCallbackRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
||||
|
||||
Resolution order:
|
||||
1. ``oauth_accounts`` row match → existing user, log in.
|
||||
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
||||
3. No match → create new user (password_hash=None, avatar from provider).
|
||||
"""
|
||||
provider_factory = _PROVIDERS.get(provider)
|
||||
if provider_factory is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||
|
||||
# Validate state (CSRF protection).
|
||||
now = time.time()
|
||||
entry = _pending_states.pop(body.state, None)
|
||||
if entry is None or entry[1] < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||
|
||||
code_verifier, _ = entry
|
||||
|
||||
oauth_provider = provider_factory()
|
||||
|
||||
# Exchange code for tokens.
|
||||
try:
|
||||
token_data = await oauth_provider.exchange_code(
|
||||
code=body.code,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
||||
)
|
||||
|
||||
access_token_google = token_data.get("access_token")
|
||||
if not access_token_google:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
||||
|
||||
# Fetch user identity.
|
||||
try:
|
||||
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
||||
except Exception:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
||||
|
||||
# ── Resolution order ──────────────────────────────────────────────
|
||||
|
||||
# 1. Existing OAuth link?
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
||||
)
|
||||
)
|
||||
oauth_account = oauth_result.scalar_one_or_none()
|
||||
|
||||
if oauth_account is not None:
|
||||
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
||||
user = user_result.scalar_one()
|
||||
# Backfill avatar if the user doesn't have one yet.
|
||||
if user.avatar_url is None and userinfo.avatar_url:
|
||||
user.avatar_url = userinfo.avatar_url
|
||||
await db.commit()
|
||||
plain_token, tokens = await _issue_refresh_token(user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
# 2. Email match with a verified Google email → link accounts.
|
||||
if userinfo.email_verified:
|
||||
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
||||
existing_user = email_result.scalar_one_or_none()
|
||||
|
||||
if existing_user is not None:
|
||||
new_link = OAuthAccount(
|
||||
user_id=existing_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=userinfo.provider_user_id,
|
||||
provider_email=userinfo.email,
|
||||
)
|
||||
db.add(new_link)
|
||||
if existing_user.avatar_url is None and userinfo.avatar_url:
|
||||
existing_user.avatar_url = userinfo.avatar_url
|
||||
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
||||
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
||||
if not userinfo.email_verified:
|
||||
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
||||
if conflict.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status.HTTP_409_CONFLICT,
|
||||
"An account with this email already exists. "
|
||||
"Please sign in with your password.",
|
||||
)
|
||||
|
||||
# 3. New user — social-only account (no password).
|
||||
new_user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=userinfo.email,
|
||||
name=userinfo.name,
|
||||
password_hash=None,
|
||||
avatar_url=userinfo.avatar_url,
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # populate new_user.id
|
||||
|
||||
new_oauth = OAuthAccount(
|
||||
user_id=new_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=userinfo.provider_user_id,
|
||||
provider_email=userinfo.email,
|
||||
)
|
||||
db.add(new_oauth)
|
||||
|
||||
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
|
||||
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||
|
||||
# We can't call the FastAPI dependency directly, but we can replicate
|
||||
# the core logic inline. Instead, we just re-query the same way.
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
user_result = await db.execute(
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
)
|
||||
|
||||
|
||||
# ── Onboarding routes ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateMemoryRequest(BaseModel):
|
||||
memory: dict[str, str] = Field(default_factory=dict)
|
||||
mark_onboarded: bool = False
|
||||
|
||||
|
||||
@router.put("/me/memory", response_model=UserProfile)
|
||||
async def update_memory(
|
||||
body: _UpdateMemoryRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||
mw = MemoryMiddleware(db)
|
||||
for key, value in body.memory.items():
|
||||
await mw.update_core(current_user.id, key, value)
|
||||
if body.mark_onboarded:
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
@router.post("/me/onboarding/reset")
|
||||
async def reset_onboarding(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Reset onboarding so the wizard runs again on next login."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = None
|
||||
await db.commit()
|
||||
return {"status": "reset"}
|
||||
|
||||
|
||||
class _NormalizeRequest(BaseModel):
|
||||
inputs: dict[str, str]
|
||||
|
||||
|
||||
class _NormalizeResponse(BaseModel):
|
||||
normalized: dict[str, str]
|
||||
|
||||
|
||||
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||
async def normalize_onboarding(
|
||||
body: _NormalizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _NormalizeResponse:
|
||||
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||
if not body.inputs:
|
||||
return _NormalizeResponse(normalized={})
|
||||
try:
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||
prompt = (
|
||||
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||
"Return a JSON object with the same keys and normalized values.\n"
|
||||
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||
f"Input: {json.dumps(body.inputs)}"
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
normalized = json.loads(response.content)
|
||||
return _NormalizeResponse(normalized=normalized)
|
||||
except Exception:
|
||||
# LLM failure must never block onboarding — return inputs unchanged
|
||||
return _NormalizeResponse(normalized=body.inputs)
|
||||
|
||||
|
||||
# ── Password management ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class _ChangePasswordRequest(BaseModel):
|
||||
current_password: str = Field(min_length=1)
|
||||
new_password: str = Field(min_length=8)
|
||||
|
||||
|
||||
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||
async def change_password(
|
||||
body: _ChangePasswordRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Change the authenticated user's password.
|
||||
|
||||
Requires the current password for verification.
|
||||
Returns 400 for social-only users (no password set).
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"This account uses social login and has no password to change",
|
||||
)
|
||||
|
||||
if not _verify_password(body.current_password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||
|
||||
user.password_hash = _hash_password(body.new_password)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── OAuth account management ─────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||
async def list_oauth_accounts(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict]:
|
||||
"""List all OAuth providers linked to the authenticated user."""
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"provider": a.provider,
|
||||
"provider_email": a.provider_email,
|
||||
"created_at": int(a.created_at.timestamp() * 1000),
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||
async def unlink_oauth_account(
|
||||
provider: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Unlink an OAuth provider from the authenticated user.
|
||||
|
||||
Refuses if the user has no password and this is their only login method.
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.user_id == current_user.id,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
account = oauth_result.scalar_one_or_none()
|
||||
if account is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||
|
||||
# Safety: don't let users lock themselves out.
|
||||
all_oauth = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
oauth_count = len(all_oauth.scalars().all())
|
||||
|
||||
if user.password_hash is None and oauth_count <= 1:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"Cannot unlink the only login method. Set a password first.",
|
||||
)
|
||||
|
||||
await db.delete(account)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Avatar update ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateAvatarRequest(BaseModel):
|
||||
avatar_url: str = Field(min_length=1)
|
||||
|
||||
|
||||
@router.put("/me/avatar", response_model=UserProfile)
|
||||
async def update_avatar(
|
||||
body: _UpdateAvatarRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's avatar URL.
|
||||
|
||||
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||
to its own storage and passes the resulting URL here.
|
||||
"""
|
||||
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.avatar_url = body.avatar_url
|
||||
await db.commit()
|
||||
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
# ── Account deletion ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||
async def delete_account(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Permanently delete the authenticated user's account.
|
||||
|
||||
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||
is cancelled if active.
|
||||
"""
|
||||
# Cancel Stripe subscription if present.
|
||||
try:
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
except HTTPException:
|
||||
pass # No subscription — that's fine
|
||||
|
||||
# Delete all memory rows (core, associative, episodic, proactive).
|
||||
try:
|
||||
from app.models import ( # noqa: PLC0415
|
||||
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||
)
|
||||
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||
await db.execute(
|
||||
model.__table__.delete().where(model.user_id == current_user.id)
|
||||
)
|
||||
except Exception:
|
||||
pass # Non-critical — cascade on User will handle most
|
||||
|
||||
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
@@ -83,3 +83,16 @@ async def cancel_subscription(
|
||||
"""Cancel the active subscription."""
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/invoices", response_model=list[dict])
|
||||
async def list_invoices(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return billing history (invoices) from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured.
|
||||
"""
|
||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||
return invoices
|
||||
|
||||
1
app/auth/__init__.py
Normal file
1
app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"OAuth provider abstractions and utilities."
|
||||
135
app/auth/oauth_providers.py
Normal file
135
app/auth/oauth_providers.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""OAuth 2.0 + PKCE provider abstractions.
|
||||
|
||||
Each provider implements a three-step flow designed for a desktop (public) client:
|
||||
|
||||
1. get_authorization_url(state, code_challenge) → str
|
||||
Build the provider's consent-screen URL. State and code_challenge are
|
||||
generated server-side; the client opens this URL in the system browser.
|
||||
|
||||
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
||||
Exchange the short-lived authorization code for an access token.
|
||||
The code_verifier proves ownership of the PKCE challenge.
|
||||
|
||||
3. get_userinfo(access_token) → OAuthUserInfo
|
||||
Fetch the canonical user identity from the provider.
|
||||
|
||||
Currently supported providers:
|
||||
- GoogleOAuthProvider (scope: openid email profile)
|
||||
|
||||
Adding a new provider:
|
||||
- Implement the three methods above.
|
||||
- Register in _PROVIDERS inside routes/auth.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
# ── Data transfer objects ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
"""Normalized user identity returned by any provider."""
|
||||
|
||||
provider_user_id: str
|
||||
email: str
|
||||
email_verified: bool
|
||||
avatar_url: str | None
|
||||
name: str | None
|
||||
|
||||
|
||||
# ── PKCE helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
||||
|
||||
The code_verifier is a random 32-byte URL-safe base64 string.
|
||||
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
||||
"""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
# ── Google provider ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GoogleOAuthProvider:
|
||||
"""Google OAuth 2.0 provider (openid email profile scope).
|
||||
|
||||
Uses Google's standard authorization endpoint with PKCE S256.
|
||||
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
||||
"""
|
||||
|
||||
name = "google"
|
||||
|
||||
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
||||
"""Build the Google consent-screen URL."""
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "select_account",
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code(
|
||||
self, code: str, code_verifier: str, redirect_uri: str
|
||||
) -> dict:
|
||||
"""Exchange authorization code for an access token."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self._TOKEN_URL,
|
||||
data={
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"code_verifier": code_verifier,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Fetch the authenticated user's identity from Google."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self._USERINFO_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return OAuthUserInfo(
|
||||
provider_user_id=data["sub"],
|
||||
email=data["email"],
|
||||
email_verified=data.get("email_verified", False),
|
||||
avatar_url=data.get("picture"),
|
||||
name=data.get("name"),
|
||||
)
|
||||
@@ -200,6 +200,45 @@ class StripeService:
|
||||
sub.status = "canceled"
|
||||
await db.commit()
|
||||
|
||||
async def list_invoices(
|
||||
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return recent invoices for the user from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured or the user has
|
||||
no ``stripe_customer_id``.
|
||||
"""
|
||||
if not self._configured():
|
||||
return []
|
||||
|
||||
from app.models import User # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(User.stripe_customer_id).where(User.id == user_id)
|
||||
)
|
||||
customer_id = result.scalar_one_or_none()
|
||||
if not customer_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
s = self._client()
|
||||
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||
return [
|
||||
{
|
||||
"id": inv.id,
|
||||
"amount_due": inv.amount_due,
|
||||
"amount_paid": inv.amount_paid,
|
||||
"currency": inv.currency,
|
||||
"status": inv.status,
|
||||
"created": inv.created * 1000, # epoch ms
|
||||
"invoice_url": inv.hosted_invoice_url,
|
||||
"invoice_pdf": inv.invoice_pdf,
|
||||
}
|
||||
for inv in invoices.auto_paging_iter()
|
||||
]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# ── Private DB helpers ───────────────────────────────────────────────
|
||||
|
||||
async def _upsert_subscription(
|
||||
|
||||
@@ -20,6 +20,14 @@ class Settings(BaseSettings):
|
||||
LLM_MODEL: str = "gpt-4o"
|
||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||
|
||||
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
||||
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||
@@ -33,12 +41,29 @@ class Settings(BaseSettings):
|
||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||
MS_TENANT_ID: str = "common"
|
||||
|
||||
# Google Login OAuth credentials — scope: openid email profile.
|
||||
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
||||
GOOGLE_AUTH_CLIENT_ID: str = ""
|
||||
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
||||
# The redirect URI registered in Google Cloud Console.
|
||||
# Google redirects here after consent; this backend route then bounces to
|
||||
# the adiuvai:// deep link so the Electron app receives the code.
|
||||
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
||||
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||
|
||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||
OAUTH_ENCRYPTION_KEY: str = ""
|
||||
|
||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||
CORS_ORIGINS: list[str] = [
|
||||
"app://.",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
"http://localhost:4173", # Vite preview (web SPA)
|
||||
"https://app.adiuvai.com", # Production web portal
|
||||
]
|
||||
|
||||
LANGFUSE_SECRET_KEY: str = ""
|
||||
LANGFUSE_PUBLIC_KEY: str = ""
|
||||
|
||||
@@ -30,7 +30,6 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
@@ -43,10 +42,9 @@ from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.config.settings import settings
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.core.preprocessors import detect_content_type, preprocess
|
||||
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
||||
from app.db import async_session
|
||||
@@ -74,13 +72,13 @@ _MAX_PROCESSING_STEPS: int = 12
|
||||
_MAX_SCAN_DEPTH: int = 5
|
||||
|
||||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||
# NOTE: "projects" is intentionally excluded — project creation/assignment is
|
||||
# handled in code by the runner, never delegated to the Step 2 LLM.
|
||||
|
||||
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||
"tasks": TASK_TOOLS,
|
||||
"notes": NOTE_TOOLS,
|
||||
"timelines": TIMELINE_TOOLS,
|
||||
"timelineEvents": TIMELINE_TOOLS,
|
||||
"projects": PROJECT_TOOLS,
|
||||
}
|
||||
|
||||
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
|
||||
@@ -228,6 +226,7 @@ async def _run_agent_with_tools(
|
||||
tools: list[Any],
|
||||
max_steps: int,
|
||||
user_id: str = "",
|
||||
session_id: str = "",
|
||||
langfuse_prompt: Any = None,
|
||||
agent_name: str = "batch-agent",
|
||||
_tool_calls_out: list[str] | None = None,
|
||||
@@ -238,7 +237,7 @@ async def _run_agent_with_tools(
|
||||
run is appended to it (used by the caller to count ``create_*`` calls).
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm(agent_name)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
@@ -247,6 +246,9 @@ async def _run_agent_with_tools(
|
||||
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||
_lf_ctx.__enter__()
|
||||
|
||||
_span_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
@@ -264,7 +266,7 @@ async def _run_agent_with_tools(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name=f"{agent_name}-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent(agent_name),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
@@ -273,7 +275,7 @@ async def _run_agent_with_tools(
|
||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
if _gen_ctx:
|
||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
_gen_ctx.__exit__(None, None, None)
|
||||
|
||||
messages.append(response)
|
||||
@@ -318,6 +320,7 @@ async def _run_agent_with_tools(
|
||||
finally:
|
||||
if _span_ctx:
|
||||
_span_ctx.__exit__(None, None, None)
|
||||
_lf_ctx.__exit__(None, None, None)
|
||||
if lf:
|
||||
lf.flush()
|
||||
|
||||
@@ -386,7 +389,8 @@ async def _scan_directories(
|
||||
for file_path in all_files:
|
||||
try:
|
||||
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||
modified_at = meta.get("modifiedAt")
|
||||
# FE sends snake_case keys on the wire (toSnakeCase transform)
|
||||
modified_at = meta.get("modified_at") or meta.get("modifiedAt")
|
||||
if modified_at is None:
|
||||
filtered.append(file_path)
|
||||
continue
|
||||
@@ -607,7 +611,6 @@ async def run_local_agent(
|
||||
|
||||
try:
|
||||
# ── Code: scan directories ───────────────────────────────────
|
||||
logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id)
|
||||
file_paths = await _scan_directories(
|
||||
paths=config.directory_paths,
|
||||
extensions=config.file_extensions or [],
|
||||
@@ -686,6 +689,7 @@ async def run_local_agent(
|
||||
tools=processing_tools,
|
||||
max_steps=_MAX_PROCESSING_STEPS,
|
||||
user_id=user_id,
|
||||
session_id=run_id,
|
||||
langfuse_prompt=prompt_obj,
|
||||
agent_name="unified-processor",
|
||||
_tool_calls_out=file_tool_calls,
|
||||
@@ -696,6 +700,12 @@ async def run_local_agent(
|
||||
)
|
||||
items_created += file_created
|
||||
|
||||
# Refresh project list when a project was created so
|
||||
# subsequent files see it in the prompt context.
|
||||
if "create_project" in file_tool_calls:
|
||||
projects = await _fetch_projects()
|
||||
projects_block = _format_projects(projects)
|
||||
|
||||
logger.info(
|
||||
"agent_runner: run=%s file=%r created=%d result=%s",
|
||||
run_id, file_path, file_created, result_text[:200],
|
||||
@@ -911,6 +921,7 @@ async def run_cloud_agent(
|
||||
tools=processing_tools,
|
||||
max_steps=_MAX_PROCESSING_STEPS,
|
||||
user_id=user_id,
|
||||
session_id=run_id,
|
||||
langfuse_prompt=cloud_prompt_obj,
|
||||
agent_name="cloud-processor",
|
||||
)
|
||||
|
||||
@@ -16,9 +16,8 @@ from app.agents.note_agent import NOTE_TOOLS
|
||||
from app.agents.project_agent import PROJECT_TOOLS
|
||||
from app.agents.task_agent import TASK_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.config.settings import settings
|
||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||
from app.db import async_session
|
||||
@@ -28,6 +27,34 @@ logger = logging.getLogger(__name__)
|
||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||
|
||||
# Mapping of core-memory language values to natural-language names for prompts.
|
||||
_LANGUAGE_NAMES: dict[str, str] = {
|
||||
"en": "English", "it": "Italian", "es": "Spanish",
|
||||
"fr": "French", "de": "German",
|
||||
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||
"spanish": "Spanish", "español": "Spanish",
|
||||
"french": "French", "français": "French",
|
||||
"german": "German", "deutsch": "German",
|
||||
}
|
||||
|
||||
|
||||
def _language_instruction(context: dict[str, Any]) -> str:
|
||||
"""Return a system-prompt suffix that tells the LLM to respond in the user's language.
|
||||
|
||||
Returns an empty string when the language is English or unknown — saves tokens.
|
||||
"""
|
||||
core = context.get("core_memory") or {}
|
||||
raw = (core.get("language") or "").strip().lower()
|
||||
if not raw:
|
||||
return ""
|
||||
lang = _LANGUAGE_NAMES.get(raw, raw.title()) # best-effort capitalisation
|
||||
if lang.lower() == "english":
|
||||
return ""
|
||||
return (
|
||||
f"\n\nIMPORTANT: Always respond in {lang}. "
|
||||
f"All your output text must be written in {lang}."
|
||||
)
|
||||
|
||||
_HOME_SYSTEM_PROMPT = (
|
||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||
"Always use tools for factual data retrieval before answering. "
|
||||
@@ -149,6 +176,15 @@ def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _session_id_from_context(context: dict[str, Any]) -> str | None:
|
||||
debug = context.get("_debug")
|
||||
if isinstance(debug, dict):
|
||||
session_id = debug.get("session_id")
|
||||
if isinstance(session_id, str) and session_id:
|
||||
return session_id
|
||||
return None
|
||||
|
||||
|
||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||
sanitized = dict(context)
|
||||
sanitized.pop("_debug", None)
|
||||
@@ -537,7 +573,7 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
||||
}
|
||||
|
||||
try:
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm("classifier")
|
||||
classifier_messages = [
|
||||
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
|
||||
HumanMessage(
|
||||
@@ -551,18 +587,25 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
||||
_, classifier_prompt_obj = get_prompt_or_fallback(
|
||||
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
|
||||
)
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="floating-classifier",
|
||||
model=settings.LLM_MODEL,
|
||||
prompt=classifier_prompt_obj,
|
||||
input=classifier_messages,
|
||||
) as gen:
|
||||
|
||||
# Extract user/session from context for Langfuse attribution
|
||||
_debug = context.get("_debug") if isinstance(context, dict) else None
|
||||
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
|
||||
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
|
||||
|
||||
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="floating-classifier",
|
||||
model=model_for_agent("classifier"),
|
||||
prompt=classifier_prompt_obj,
|
||||
input=classifier_messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(classifier_messages)
|
||||
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(classifier_messages)
|
||||
gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(classifier_messages)
|
||||
parsed = _parse_json_object(_as_text(response.content))
|
||||
if parsed is not None:
|
||||
domain = _normalize_domain_payload(parsed, project_id)
|
||||
@@ -591,8 +634,9 @@ async def _run_single_agent(
|
||||
agent_name: str = "agent",
|
||||
) -> str:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
session_id = _session_id_from_context(context)
|
||||
lf = get_langfuse()
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm(agent_name)
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
@@ -611,6 +655,9 @@ async def _run_single_agent(
|
||||
collected: list[dict[str, Any]] = []
|
||||
set_tool_result_collector(collected)
|
||||
|
||||
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
|
||||
_lf_ctx.__enter__()
|
||||
|
||||
_span_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
@@ -628,7 +675,7 @@ async def _run_single_agent(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name=f"{agent_name}-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent(agent_name),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
@@ -637,7 +684,7 @@ async def _run_single_agent(
|
||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
if _gen_ctx:
|
||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
_gen_ctx.__exit__(None, None, None)
|
||||
|
||||
messages.append(response)
|
||||
@@ -699,6 +746,7 @@ async def _run_single_agent(
|
||||
clear_tool_result_collector()
|
||||
if _span_ctx:
|
||||
_span_ctx.__exit__(None, None, None)
|
||||
_lf_ctx.__exit__(None, None, None)
|
||||
if lf:
|
||||
lf.flush()
|
||||
|
||||
@@ -714,8 +762,9 @@ async def _run_single_agent_stream(
|
||||
agent_name: str = "agent",
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
trace_id = _trace_id_from_context(context)
|
||||
session_id = _session_id_from_context(context)
|
||||
lf = get_langfuse()
|
||||
llm = get_llm()
|
||||
llm = get_agent_llm(agent_name)
|
||||
tools = _all_tools_for_user(user_id, trace_id)
|
||||
model_context = _context_for_model(context)
|
||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||
@@ -735,6 +784,9 @@ async def _run_single_agent_stream(
|
||||
collected: list[dict[str, Any]] = []
|
||||
set_tool_result_collector(collected)
|
||||
|
||||
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
|
||||
_lf_ctx.__enter__()
|
||||
|
||||
_span_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
@@ -753,7 +805,7 @@ async def _run_single_agent_stream(
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name=f"{agent_name}-llm",
|
||||
model=settings.LLM_MODEL,
|
||||
model=model_for_agent(agent_name),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
@@ -762,7 +814,7 @@ async def _run_single_agent_stream(
|
||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
if _gen_ctx:
|
||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
_gen_ctx.__exit__(None, None, None)
|
||||
|
||||
messages.append(response)
|
||||
@@ -842,6 +894,7 @@ async def _run_single_agent_stream(
|
||||
clear_tool_result_collector()
|
||||
if _span_ctx:
|
||||
_span_ctx.__exit__(None, None, None)
|
||||
_lf_ctx.__exit__(None, None, None)
|
||||
if lf:
|
||||
lf.flush()
|
||||
|
||||
@@ -851,6 +904,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
@@ -868,6 +922,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
response = await _run_single_agent(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
@@ -891,6 +946,7 @@ async def run_home_stream(
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"home_system", _HOME_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
text_chunks: list[str] = []
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
@@ -923,6 +979,7 @@ async def run_floating_stream(
|
||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||
)
|
||||
system_prompt += _language_instruction(context)
|
||||
sanitizer = _FloatingStreamSanitizer()
|
||||
emitted_sanitized = False
|
||||
raw_chunks: list[str] = []
|
||||
|
||||
@@ -39,8 +39,10 @@ Linking a prompt to a generation::
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Any
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -145,3 +147,44 @@ def extract_usage(response: Any) -> dict[str, int]:
|
||||
"output": int(meta.get("output_tokens", 0)),
|
||||
"total": int(meta.get("total_tokens", 0)),
|
||||
}
|
||||
|
||||
|
||||
def hash_user_id(user_id: str) -> str:
|
||||
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
||||
|
||||
This avoids sending raw database UUIDs to external observability services
|
||||
while still providing a stable, deterministic identifier for per-user
|
||||
metrics in the Langfuse dashboard.
|
||||
"""
|
||||
return hashlib.sha256(user_id.encode()).hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def langfuse_context(
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
||||
|
||||
No-op when Langfuse is not configured or parameters are empty.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
if lf is None or (not user_id and not session_id):
|
||||
yield
|
||||
return
|
||||
|
||||
try:
|
||||
from langfuse import propagate_attributes
|
||||
except ImportError:
|
||||
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
||||
yield
|
||||
return
|
||||
|
||||
attrs: dict[str, str] = {}
|
||||
if user_id:
|
||||
attrs["user_id"] = hash_user_id(user_id)
|
||||
if session_id:
|
||||
attrs["session_id"] = session_id
|
||||
|
||||
with propagate_attributes(**attrs):
|
||||
yield
|
||||
|
||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
@@ -95,6 +96,35 @@ def get_llm(
|
||||
)
|
||||
|
||||
|
||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
}
|
||||
|
||||
|
||||
def model_for_agent(agent_name: str) -> str:
|
||||
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||
|
||||
|
||||
def get_agent_llm(
|
||||
agent_name: str,
|
||||
*,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||
|
||||
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||
per-agent override is left empty in ``.env``.
|
||||
"""
|
||||
model = model_for_agent(agent_name)
|
||||
return get_llm(model=model, temperature=temperature)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
"""Return an embedding vector for *text*.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -69,7 +69,8 @@ class User(Base):
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||
@@ -78,6 +79,9 @@ class User(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, default=None
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -88,6 +92,9 @@ class User(Base):
|
||||
subscription: Mapped[Subscription | None] = relationship(
|
||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||
)
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
back_populates="user", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(Base):
|
||||
@@ -108,6 +115,25 @@ class RefreshToken(Base):
|
||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||
|
||||
|
||||
class OAuthAccount(Base):
|
||||
__tablename__ = "oauth_accounts"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship(back_populates="oauth_accounts")
|
||||
|
||||
|
||||
class Subscription(Base):
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
|
||||
@@ -30,6 +30,16 @@ class UserProfile(BaseModel):
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
tier: BillingTier
|
||||
avatar_url: str | None = None
|
||||
has_password: bool = True
|
||||
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||
|
||||
|
||||
class OAuthAccountInfo(BaseModel):
|
||||
provider: str
|
||||
provider_email: str | None = None
|
||||
created_at: int # epoch ms
|
||||
|
||||
|
||||
# ── Chat ─────────────────────────────────────────────────────────────
|
||||
@@ -236,10 +246,11 @@ class AgentTriggerRequest(BaseModel):
|
||||
device_id: str = Field(default="")
|
||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||
what_to_extract: list[str] = Field(min_length=1)
|
||||
actions_by_type: dict[str, list[str]] | None = None
|
||||
batch_interval: str = Field(min_length=1)
|
||||
custom_agent_prompt: str = Field(min_length=1)
|
||||
custom_agent_prompt: str | None = None
|
||||
agent_config: dict | None = None
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
||||
|
||||
|
||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||
|
||||
@@ -28,7 +28,6 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.core.agent_runner import (
|
||||
_extract_items_from_content,
|
||||
@@ -597,7 +596,7 @@ async def test_run_cloud_agent_provider_fetch_error():
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||
from app.integrations import EmailMessage, encrypt_token
|
||||
from app.integrations import encrypt_token
|
||||
from cryptography.fernet import Fernet as _Fernet
|
||||
|
||||
fernet_key = _Fernet.generate_key().decode()
|
||||
@@ -791,7 +790,6 @@ async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||
json={
|
||||
"directory": "/home/user/docs",
|
||||
"what_to_extract": ["task", "note"],
|
||||
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
||||
"batch_interval": "0 */6 * * *",
|
||||
"custom_agent_prompt": "Extract tasks and notes.",
|
||||
"active_agents": 0,
|
||||
|
||||
@@ -40,7 +40,6 @@ from app.core.agent_runner import (
|
||||
_format_projects,
|
||||
_get_extraction_rules,
|
||||
_get_no_match_behavior,
|
||||
_is_overdue,
|
||||
run_local_agent,
|
||||
)
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
|
||||
@@ -21,7 +21,6 @@ import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Tests for auth routes: register, login, refresh, me.
|
||||
"""Tests for auth routes: register, login, refresh, me, OAuth social login.
|
||||
|
||||
Exercises the full auth lifecycle through the FastAPI TestClient against the
|
||||
in-memory SQLite test database seeded by ``conftest.py``.
|
||||
@@ -7,9 +7,11 @@ in-memory SQLite test database seeded by ``conftest.py``.
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from jose import jwt
|
||||
|
||||
from app.auth.oauth_providers import GoogleOAuthProvider, OAuthUserInfo
|
||||
from app.config.settings import settings
|
||||
from tests.conftest import auth_header, TEST_USER_IDS
|
||||
|
||||
@@ -204,3 +206,153 @@ class TestMe:
|
||||
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ── TestOAuth ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOAuth:
|
||||
"""GET /auth/oauth/google/authorize and POST /auth/oauth/google/callback."""
|
||||
|
||||
FAKE_PROVIDER_USER_ID = "google-sub-12345"
|
||||
FAKE_EMAIL = "oauth@example.com"
|
||||
FAKE_AVATAR = "https://lh3.googleusercontent.com/photo.jpg"
|
||||
|
||||
def _patch_google(self, monkeypatch) -> None:
|
||||
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "fake-client-id")
|
||||
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "fake-client-secret")
|
||||
|
||||
def _userinfo(
|
||||
self,
|
||||
email: str | None = None,
|
||||
email_verified: bool = True,
|
||||
) -> OAuthUserInfo:
|
||||
return OAuthUserInfo(
|
||||
provider_user_id=self.FAKE_PROVIDER_USER_ID,
|
||||
email=email or self.FAKE_EMAIL,
|
||||
email_verified=email_verified,
|
||||
avatar_url=self.FAKE_AVATAR,
|
||||
name="OAuth User",
|
||||
)
|
||||
|
||||
def _authorize(self, client) -> str:
|
||||
"""Call /authorize and return the fresh state token."""
|
||||
resp = client.get("/api/v1/auth/oauth/google/authorize")
|
||||
assert resp.status_code == 200
|
||||
return resp.json()["state"]
|
||||
|
||||
def _callback(self, client, state: str, userinfo: OAuthUserInfo):
|
||||
"""POST /callback with mocked provider exchange_code + get_userinfo."""
|
||||
with (
|
||||
patch.object(
|
||||
GoogleOAuthProvider,
|
||||
"exchange_code",
|
||||
new=AsyncMock(return_value={"access_token": "google-access-tok"}),
|
||||
),
|
||||
patch.object(
|
||||
GoogleOAuthProvider,
|
||||
"get_userinfo",
|
||||
new=AsyncMock(return_value=userinfo),
|
||||
),
|
||||
):
|
||||
return client.post(
|
||||
"/api/v1/auth/oauth/google/callback",
|
||||
json={"code": "auth-code", "state": state},
|
||||
)
|
||||
|
||||
def _decode_sub(self, access_token: str) -> str:
|
||||
return jwt.decode(
|
||||
access_token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)["sub"]
|
||||
|
||||
# -- authorize --
|
||||
|
||||
def test_authorize_returns_url_and_state(self, client, monkeypatch) -> None:
|
||||
self._patch_google(monkeypatch)
|
||||
resp = client.get("/api/v1/auth/oauth/google/authorize")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "url" in data and "state" in data
|
||||
assert "accounts.google.com" in data["url"]
|
||||
assert len(data["state"]) > 0
|
||||
|
||||
def test_authorize_unconfigured_returns_503(self, client, monkeypatch) -> None:
|
||||
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "")
|
||||
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "")
|
||||
resp = client.get("/api/v1/auth/oauth/google/authorize")
|
||||
assert resp.status_code == 503
|
||||
|
||||
# -- callback --
|
||||
|
||||
def test_callback_state_mismatch_returns_401(self, client, monkeypatch) -> None:
|
||||
self._patch_google(monkeypatch)
|
||||
resp = client.post(
|
||||
"/api/v1/auth/oauth/google/callback",
|
||||
json={"code": "code", "state": "not-a-real-state"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
def test_callback_creates_new_user(self, client, monkeypatch) -> None:
|
||||
"""First-time Google login creates a new user and returns valid tokens."""
|
||||
self._patch_google(monkeypatch)
|
||||
state = self._authorize(client)
|
||||
resp = self._callback(client, state, self._userinfo())
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data and "refresh_token" in data
|
||||
payload = jwt.decode(
|
||||
data["access_token"], settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
assert payload["email"] == self.FAKE_EMAIL
|
||||
|
||||
def test_callback_existing_oauth_link_logs_in(self, client, monkeypatch) -> None:
|
||||
"""Second Google login with the same account re-uses the existing user."""
|
||||
self._patch_google(monkeypatch)
|
||||
userinfo = self._userinfo()
|
||||
|
||||
# First login — creates user + oauth_accounts row
|
||||
resp1 = self._callback(client, self._authorize(client), userinfo)
|
||||
assert resp1.status_code == 200
|
||||
sub1 = self._decode_sub(resp1.json()["access_token"])
|
||||
|
||||
# Second login — finds existing oauth_accounts row → same user
|
||||
resp2 = self._callback(client, self._authorize(client), userinfo)
|
||||
assert resp2.status_code == 200
|
||||
sub2 = self._decode_sub(resp2.json()["access_token"])
|
||||
|
||||
assert sub1 == sub2
|
||||
|
||||
def test_callback_email_match_links_account(self, client, monkeypatch) -> None:
|
||||
"""Verified Google email matching an existing password user links the accounts."""
|
||||
email = "link-target@example.com"
|
||||
reg_resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": email, "password": "TestPass123!"},
|
||||
)
|
||||
assert reg_resp.status_code == 201
|
||||
orig_sub = self._decode_sub(reg_resp.json()["access_token"])
|
||||
|
||||
self._patch_google(monkeypatch)
|
||||
state = self._authorize(client)
|
||||
resp = self._callback(client, state, self._userinfo(email=email, email_verified=True))
|
||||
|
||||
assert resp.status_code == 200
|
||||
oauth_sub = self._decode_sub(resp.json()["access_token"])
|
||||
# OAuth login must resolve to the same user as the original registration
|
||||
assert orig_sub == oauth_sub
|
||||
|
||||
def test_callback_unverified_email_conflict_returns_409(self, client, monkeypatch) -> None:
|
||||
"""Unverified Google email matching an existing account returns 409, not 500."""
|
||||
email = "conflict@example.com"
|
||||
reg_resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": email, "password": "TestPass123!"},
|
||||
)
|
||||
assert reg_resp.status_code == 201
|
||||
|
||||
self._patch_google(monkeypatch)
|
||||
state = self._authorize(client)
|
||||
resp = self._callback(client, state, self._userinfo(email=email, email_verified=False))
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
@@ -18,13 +18,12 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import AgentRunLog
|
||||
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
|
||||
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
||||
@@ -40,11 +40,9 @@ Coverage:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import (
|
||||
|
||||
@@ -7,10 +7,9 @@ column is stored as JSON in tests (SQLite-compatible).
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from __future__ import annotations
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from app.core.preprocessors import detect_content_type, preprocess
|
||||
|
||||
Reference in New Issue
Block a user