Before: branch 3 of oauth_callback attempted to INSERT a user with a duplicate email → DB constraint violation → 500. After: if email_verified=False and the email already exists, raise 409 with a message directing the user to sign in with their password. Also adds test_callback_unverified_email_conflict_returns_409. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
359 lines
14 KiB
Python
359 lines
14 KiB
Python
"""Tests for auth routes: register, login, refresh, me, OAuth social login.
|
|
|
|
Exercises the full auth lifecycle through the FastAPI TestClient against the
|
|
in-memory SQLite test database seeded by ``conftest.py``.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
from jose import jwt
|
|
|
|
from app.auth.oauth_providers import GoogleOAuthProvider, OAuthUserInfo
|
|
from app.config.settings import settings
|
|
from tests.conftest import auth_header, TEST_USER_IDS
|
|
|
|
|
|
# ── TestRegister ──────────────────────────────────────────────────────
|
|
|
|
|
|
class TestRegister:
    """POST /api/v1/auth/register"""

    def test_register_success(self, client) -> None:
        """A fresh email/password pair returns 201 with a full token payload."""
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": "new@example.com", "password": "Str0ngP@ss!"},
        )
        assert resp.status_code == 201
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert "expires_at" in data
        # expires_at should be a future millisecond timestamp
        assert data["expires_at"] > int(time.time() * 1000)

    def test_register_returns_valid_jwt(self, client) -> None:
        """The issued access token decodes with the configured secret and algorithm."""
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": "jwt-check@example.com", "password": "P@ss1234"},
        )
        assert resp.status_code == 201
        token = resp.json()["access_token"]
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
        assert payload["email"] == "jwt-check@example.com"
        assert payload["tier"] == "free"
        assert "sub" in payload

    def test_register_duplicate_email(self, client) -> None:
        """Registering an already-registered email is rejected with 409."""
        first = client.post(
            "/api/v1/auth/register",
            json={"email": "dupe@example.com", "password": "Pass1234"},
        )
        # Pin the precondition: the 409 below is only meaningful if the first
        # registration really created the account.
        assert first.status_code == 201
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": "dupe@example.com", "password": "Pass5678"},
        )
        assert resp.status_code == 409

    def test_register_missing_password(self, client) -> None:
        """A body without ``password`` is rejected with 422."""
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": "no-pass@example.com"},
        )
        assert resp.status_code == 422

    def test_register_missing_email(self, client) -> None:
        """A body without ``email`` is rejected with 422."""
        resp = client.post(
            "/api/v1/auth/register",
            json={"password": "OnlyPass"},
        )
        assert resp.status_code == 422
|
|
|
|
|
|
# ── TestLogin ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestLogin:
    """POST /api/v1/auth/login"""

    def _register(self, client, email="login@example.com", password="MyP@ss123") -> None:
        """Register a password user so the login tests have an account to hit.

        Fails immediately if registration does not return 201. Without this
        check, ``test_login_wrong_password`` would still see a 401 (the
        unknown-email response) and pass without exercising the password check.
        """
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": password},
        )
        assert resp.status_code == 201

    def test_login_success(self, client) -> None:
        """Correct credentials return 200 with a full token payload."""
        self._register(client)
        resp = client.post(
            "/api/v1/auth/login",
            json={"email": "login@example.com", "password": "MyP@ss123"},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert "expires_at" in data

    def test_login_wrong_password(self, client) -> None:
        """A registered email with the wrong password returns 401."""
        self._register(client)
        resp = client.post(
            "/api/v1/auth/login",
            json={"email": "login@example.com", "password": "WrongPass!"},
        )
        assert resp.status_code == 401

    def test_login_unknown_email(self, client) -> None:
        """An email that was never registered returns 401."""
        resp = client.post(
            "/api/v1/auth/login",
            json={"email": "ghost@example.com", "password": "Whatever"},
        )
        assert resp.status_code == 401
|
|
|
|
|
|
# ── TestRefresh ───────────────────────────────────────────────────────
|
|
|
|
|
|
class TestRefresh:
    """POST /api/v1/auth/refresh"""

    def _register_and_get_tokens(self, client, email="refresh@example.com"):
        """Register ``email`` and return the decoded JSON token payload.

        Asserts the 201 first, so a broken registration fails here rather than
        as a ``KeyError`` on ``tokens["refresh_token"]`` in the calling test.
        """
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": "RefPass123!"},
        )
        assert resp.status_code == 201
        return resp.json()

    def test_refresh_returns_new_tokens(self, client) -> None:
        """A valid refresh token is exchanged for a new, different token pair."""
        tokens = self._register_and_get_tokens(client)
        resp = client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": tokens["refresh_token"]},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        # New refresh token should differ from old one (rotation)
        assert data["refresh_token"] != tokens["refresh_token"]

    def test_refresh_old_token_rejected(self, client) -> None:
        """After rotation, the original refresh token must be rejected."""
        tokens = self._register_and_get_tokens(client, email="rotate@example.com")
        old_rt = tokens["refresh_token"]

        # First refresh succeeds and rotates the token. Assert it: if this
        # call failed, the 401 below would not prove the token was rotated.
        first = client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
        assert first.status_code == 200

        # Second attempt with the old token must fail
        resp = client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
        assert resp.status_code == 401

    def test_refresh_bogus_token(self, client) -> None:
        """A string that was never issued as a refresh token returns 401."""
        resp = client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": "not-a-real-token"},
        )
        assert resp.status_code == 401
|
|
|
|
|
|
# ── TestMe ────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestMe:
    """GET /api/v1/auth/me"""

    def test_me_with_valid_jwt(self, client) -> None:
        """A valid token for the ``power`` test user returns that user's profile."""
        resp = client.get("/api/v1/auth/me", headers=auth_header("power"))
        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == TEST_USER_IDS["power"]
        assert data["email"] == "power@test.com"
        assert data["tier"] == "power"

    def test_me_returns_correct_tier(self, client) -> None:
        """The ``free`` test user is reported with the ``free`` tier.

        NOTE(review): the original note said the tier comes from the live
        subscription row rather than the JWT claim. This test alone cannot
        tell the two sources apart — confirm against ``auth_header`` and the
        seeded data in ``conftest.py``.
        """
        resp = client.get("/api/v1/auth/me", headers=auth_header("free"))
        # Check the status first so a rejected request fails with a clear
        # status mismatch instead of a KeyError on the error body.
        assert resp.status_code == 200
        assert resp.json()["tier"] == "free"

    def test_me_missing_token(self, client) -> None:
        """A request without an Authorization header returns 401."""
        resp = client.get("/api/v1/auth/me")
        assert resp.status_code == 401

    def test_me_expired_token(self, client) -> None:
        """A JWT with ``exp`` in the past must be rejected."""
        payload = {
            "sub": TEST_USER_IDS["power"],
            "email": "power@test.com",
            "tier": "power",
            "exp": int(time.time()) - 3600,  # 1 hour ago
            "iat": int(time.time()) - 7200,
        }
        token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
        resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 401

    def test_me_invalid_signature(self, client) -> None:
        """A well-formed, unexpired JWT signed with the wrong secret returns 401."""
        payload = {
            "sub": TEST_USER_IDS["power"],
            "email": "power@test.com",
            "tier": "power",
            "exp": int(time.time()) + 3600,
            "iat": int(time.time()),
        }
        # Signed with a key other than settings.JWT_SECRET.
        token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
        resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 401
|
|
|
|
|
|
# ── TestOAuth ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestOAuth:
    """GET /auth/oauth/google/authorize and POST /auth/oauth/google/callback."""

    FAKE_PROVIDER_USER_ID = "google-sub-12345"
    FAKE_EMAIL = "oauth@example.com"
    FAKE_AVATAR = "https://lh3.googleusercontent.com/photo.jpg"

    def _patch_google(self, monkeypatch) -> None:
        """Point the Google client id/secret settings at fake values."""
        fake_credentials = (
            ("GOOGLE_AUTH_CLIENT_ID", "fake-client-id"),
            ("GOOGLE_AUTH_CLIENT_SECRET", "fake-client-secret"),
        )
        for setting_name, fake_value in fake_credentials:
            monkeypatch.setattr(settings, setting_name, fake_value)

    def _userinfo(
        self,
        email: str | None = None,
        email_verified: bool = True,
    ) -> OAuthUserInfo:
        """Build the provider user-info record the mocked Google call returns.

        A falsy ``email`` falls back to ``FAKE_EMAIL``.
        """
        chosen_email = email if email else self.FAKE_EMAIL
        return OAuthUserInfo(
            provider_user_id=self.FAKE_PROVIDER_USER_ID,
            email=chosen_email,
            email_verified=email_verified,
            avatar_url=self.FAKE_AVATAR,
            name="OAuth User",
        )

    def _authorize(self, client) -> str:
        """Hit /authorize, require a 200, and hand back the issued state token."""
        authorize_resp = client.get("/api/v1/auth/oauth/google/authorize")
        assert authorize_resp.status_code == 200
        body = authorize_resp.json()
        return body["state"]

    def _callback(self, client, state: str, userinfo: OAuthUserInfo):
        """POST /callback while the provider's network calls are replaced by mocks."""
        # Both provider methods are swapped for AsyncMocks for the duration of
        # the request, so no call leaves the test process.
        fake_exchange = patch.object(
            GoogleOAuthProvider,
            "exchange_code",
            new=AsyncMock(return_value={"access_token": "google-access-tok"}),
        )
        fake_userinfo = patch.object(
            GoogleOAuthProvider,
            "get_userinfo",
            new=AsyncMock(return_value=userinfo),
        )
        with fake_exchange:
            with fake_userinfo:
                return client.post(
                    "/api/v1/auth/oauth/google/callback",
                    json={"code": "auth-code", "state": state},
                )

    def _decode_sub(self, access_token: str) -> str:
        """Decode an access token and return its ``sub`` claim."""
        claims = jwt.decode(
            access_token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        return claims["sub"]

    # -- authorize --

    def test_authorize_returns_url_and_state(self, client, monkeypatch) -> None:
        """With credentials configured, /authorize returns a Google URL and a state."""
        self._patch_google(monkeypatch)
        authorize_resp = client.get("/api/v1/auth/oauth/google/authorize")
        assert authorize_resp.status_code == 200
        body = authorize_resp.json()
        assert "url" in body and "state" in body
        assert "accounts.google.com" in body["url"]
        assert len(body["state"]) > 0

    def test_authorize_unconfigured_returns_503(self, client, monkeypatch) -> None:
        """With empty client id and secret, /authorize answers 503."""
        for setting_name in ("GOOGLE_AUTH_CLIENT_ID", "GOOGLE_AUTH_CLIENT_SECRET"):
            monkeypatch.setattr(settings, setting_name, "")
        authorize_resp = client.get("/api/v1/auth/oauth/google/authorize")
        assert authorize_resp.status_code == 503

    # -- callback --

    def test_callback_state_mismatch_returns_401(self, client, monkeypatch) -> None:
        """A state value that /authorize never issued is rejected with 401."""
        self._patch_google(monkeypatch)
        callback_body = {"code": "code", "state": "not-a-real-state"}
        callback_resp = client.post(
            "/api/v1/auth/oauth/google/callback",
            json=callback_body,
        )
        assert callback_resp.status_code == 401

    def test_callback_creates_new_user(self, client, monkeypatch) -> None:
        """First-time Google login creates a new user and returns valid tokens."""
        self._patch_google(monkeypatch)
        fresh_state = self._authorize(client)
        callback_resp = self._callback(client, fresh_state, self._userinfo())

        assert callback_resp.status_code == 200
        tokens = callback_resp.json()
        assert "access_token" in tokens and "refresh_token" in tokens
        claims = jwt.decode(
            tokens["access_token"], settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        assert claims["email"] == self.FAKE_EMAIL

    def test_callback_existing_oauth_link_logs_in(self, client, monkeypatch) -> None:
        """Second Google login with the same account re-uses the existing user."""
        self._patch_google(monkeypatch)
        google_user = self._userinfo()

        # First login — creates user + oauth_accounts row
        first_login = self._callback(client, self._authorize(client), google_user)
        assert first_login.status_code == 200
        first_sub = self._decode_sub(first_login.json()["access_token"])

        # Second login — finds existing oauth_accounts row → same user
        second_login = self._callback(client, self._authorize(client), google_user)
        assert second_login.status_code == 200
        second_sub = self._decode_sub(second_login.json()["access_token"])

        assert first_sub == second_sub

    def test_callback_email_match_links_account(self, client, monkeypatch) -> None:
        """Verified Google email matching an existing password user links the accounts."""
        email = "link-target@example.com"
        registration = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": "TestPass123!"},
        )
        assert registration.status_code == 201
        password_user_sub = self._decode_sub(registration.json()["access_token"])

        self._patch_google(monkeypatch)
        fresh_state = self._authorize(client)
        verified_user = self._userinfo(email=email, email_verified=True)
        callback_resp = self._callback(client, fresh_state, verified_user)

        assert callback_resp.status_code == 200
        google_user_sub = self._decode_sub(callback_resp.json()["access_token"])
        # OAuth login must resolve to the same user as the original registration
        assert password_user_sub == google_user_sub

    def test_callback_unverified_email_conflict_returns_409(self, client, monkeypatch) -> None:
        """Unverified Google email matching an existing account returns 409, not 500."""
        email = "conflict@example.com"
        registration = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": "TestPass123!"},
        )
        assert registration.status_code == 201

        self._patch_google(monkeypatch)
        fresh_state = self._authorize(client)
        unverified_user = self._userinfo(email=email, email_verified=False)
        callback_resp = self._callback(client, fresh_state, unverified_user)

        assert callback_resp.status_code == 409
|