# File: api/tests/test_auth.py (344 lines, 13 KiB, Python)
# Last commit: c1a8ac7669 by Roberto Musso, 2026-04-10 13:42:11 +02:00
#   test: add TestOAuth suite for Google OAuth routes
#
#   6 tests covering the authorize and callback endpoints:
#   - authorize returns URL + state, 503 when unconfigured
#   - callback: state mismatch → 401, new user creation, existing OAuth
#     link re-login (same user sub), email-match auto-linking to password user
#
#   Provider methods (exchange_code, get_userinfo) are mocked via AsyncMock
#   so tests run without hitting Google APIs.
#
#   Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

"""Tests for auth routes: register, login, refresh, me, OAuth social login.
Exercises the full auth lifecycle through the FastAPI TestClient against the
in-memory SQLite test database seeded by ``conftest.py``.
"""
from __future__ import annotations
import time
from unittest.mock import AsyncMock, patch
from jose import jwt
from app.auth.oauth_providers import GoogleOAuthProvider, OAuthUserInfo
from app.config.settings import settings
from tests.conftest import auth_header, TEST_USER_IDS
# ── TestRegister ──────────────────────────────────────────────────────
class TestRegister:
    """Exercise ``POST /api/v1/auth/register``."""

    def _post(self, client, payload: dict):
        """Send *payload* as the JSON body of a registration request."""
        return client.post("/api/v1/auth/register", json=payload)

    def test_register_success(self, client) -> None:
        response = self._post(
            client, {"email": "new@example.com", "password": "Str0ngP@ss!"}
        )
        assert response.status_code == 201
        body = response.json()
        for key in ("access_token", "refresh_token", "expires_at"):
            assert key in body
        # ``expires_at`` is compared against "now" in milliseconds, so it must
        # be a millisecond epoch timestamp that lies in the future.
        now_ms = int(time.time() * 1000)
        assert body["expires_at"] > now_ms

    def test_register_returns_valid_jwt(self, client) -> None:
        response = self._post(
            client, {"email": "jwt-check@example.com", "password": "P@ss1234"}
        )
        assert response.status_code == 201
        claims = jwt.decode(
            response.json()["access_token"],
            settings.JWT_SECRET,
            algorithms=[settings.JWT_ALGORITHM],
        )
        assert claims["email"] == "jwt-check@example.com"
        assert claims["tier"] == "free"
        assert "sub" in claims

    def test_register_duplicate_email(self, client) -> None:
        # The first request claims the address; the second must conflict.
        self._post(client, {"email": "dupe@example.com", "password": "Pass1234"})
        second = self._post(
            client, {"email": "dupe@example.com", "password": "Pass5678"}
        )
        assert second.status_code == 409

    def test_register_missing_password(self, client) -> None:
        response = self._post(client, {"email": "no-pass@example.com"})
        assert response.status_code == 422

    def test_register_missing_email(self, client) -> None:
        response = self._post(client, {"password": "OnlyPass"})
        assert response.status_code == 422
# ── TestLogin ─────────────────────────────────────────────────────────
class TestLogin:
    """Exercise ``POST /api/v1/auth/login``."""

    def _register(self, client, email="login@example.com", password="MyP@ss123"):
        """Register *email* / *password*, failing fast if registration is rejected.

        The status check matters. Without it, a failed setup would let
        ``test_login_wrong_password`` pass for the wrong reason: the login
        would be rejected because the account does not exist, not because
        the password is wrong.
        """
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": password},
        )
        assert resp.status_code == 201, resp.text

    def test_login_success(self, client) -> None:
        self._register(client)
        resp = client.post(
            "/api/v1/auth/login",
            json={"email": "login@example.com", "password": "MyP@ss123"},
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert "expires_at" in data

    def test_login_wrong_password(self, client) -> None:
        # Uses its own address so the 201 check in _register still holds if
        # the test database is shared across tests. The fixture scope is set
        # in conftest.py and is not visible here.
        email = "login-wrong-pass@example.com"
        self._register(client, email=email)
        resp = client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": "WrongPass!"},
        )
        assert resp.status_code == 401

    def test_login_unknown_email(self, client) -> None:
        resp = client.post(
            "/api/v1/auth/login",
            json={"email": "ghost@example.com", "password": "Whatever"},
        )
        assert resp.status_code == 401
# ── TestRefresh ───────────────────────────────────────────────────────
class TestRefresh:
    """Exercise ``POST /api/v1/auth/refresh``, including refresh-token rotation."""

    def _register_and_get_tokens(self, client, email="refresh@example.com"):
        """Register *email* and return the JSON token payload of the response."""
        resp = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": "RefPass123!"},
        )
        assert resp.status_code == 201, resp.text
        return resp.json()

    def _refresh(self, client, refresh_token: str):
        """POST *refresh_token* to the refresh endpoint and return the response."""
        return client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )

    def test_refresh_returns_new_tokens(self, client) -> None:
        tokens = self._register_and_get_tokens(client)
        resp = self._refresh(client, tokens["refresh_token"])
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        # New refresh token should differ from old one (rotation)
        assert data["refresh_token"] != tokens["refresh_token"]

    def test_refresh_old_token_rejected(self, client) -> None:
        """After rotation, the original refresh token must be rejected."""
        tokens = self._register_and_get_tokens(client, email="rotate@example.com")
        old_rt = tokens["refresh_token"]
        # The first refresh must succeed and actually rotate the token.
        # Otherwise the 401 below proves nothing, since a completely broken
        # refresh endpoint would also return 401 on both calls.
        first = self._refresh(client, old_rt)
        assert first.status_code == 200
        assert first.json()["refresh_token"] != old_rt
        # Second attempt with the old token must fail
        resp = self._refresh(client, old_rt)
        assert resp.status_code == 401

    def test_refresh_bogus_token(self, client) -> None:
        resp = self._refresh(client, "not-a-real-token")
        assert resp.status_code == 401
# ── TestMe ────────────────────────────────────────────────────────────
class TestMe:
    """Exercise ``GET /api/v1/auth/me``."""

    @staticmethod
    def _power_claims(exp: int, iat: int) -> dict:
        """Return JWT claims for the seeded ``power`` user with the given timestamps."""
        return {
            "sub": TEST_USER_IDS["power"],
            "email": "power@test.com",
            "tier": "power",
            "exp": exp,
            "iat": iat,
        }

    def test_me_with_valid_jwt(self, client) -> None:
        resp = client.get("/api/v1/auth/me", headers=auth_header("power"))
        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == TEST_USER_IDS["power"]
        assert data["email"] == "power@test.com"
        assert data["tier"] == "power"

    def test_me_returns_correct_tier(self, client) -> None:
        """Tier comes from the live subscription row, not the JWT claim.

        First check the unmodified free-user token. Then re-sign the same
        claims with a forged ``tier: power`` and require that ``/me`` still
        reports ``free``.
        """
        free_headers = auth_header("free")
        resp = client.get("/api/v1/auth/me", headers=free_headers)
        assert resp.status_code == 200
        assert resp.json()["tier"] == "free"

        # Assumes auth_header returns {"Authorization": "Bearer <jwt>"} signed
        # with settings.JWT_SECRET, which is the same shape the tests below
        # build by hand. TODO confirm against conftest.py.
        token = free_headers["Authorization"].split(" ", 1)[1]
        claims = jwt.get_unverified_claims(token)
        claims["tier"] = "power"
        forged = jwt.encode(claims, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
        resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {forged}"})
        assert resp.status_code == 200
        assert resp.json()["tier"] == "free"

    def test_me_missing_token(self, client) -> None:
        resp = client.get("/api/v1/auth/me")
        assert resp.status_code == 401

    def test_me_expired_token(self, client) -> None:
        """A JWT with ``exp`` in the past must be rejected."""
        now = int(time.time())
        payload = self._power_claims(exp=now - 3600, iat=now - 7200)  # expired 1 hour ago
        token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
        resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 401

    def test_me_invalid_signature(self, client) -> None:
        """A JWT signed with the wrong key must be rejected."""
        now = int(time.time())
        payload = self._power_claims(exp=now + 3600, iat=now)
        # Use the configured algorithm so that only the signing key differs
        # from a valid token. The rejection is then caused by the signature,
        # not by an algorithm mismatch.
        token = jwt.encode(payload, "wrong-secret", algorithm=settings.JWT_ALGORITHM)
        resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
        assert resp.status_code == 401
# ── TestOAuth ─────────────────────────────────────────────────────────
class TestOAuth:
    """GET /auth/oauth/google/authorize and POST /auth/oauth/google/callback."""

    FAKE_PROVIDER_USER_ID = "google-sub-12345"
    FAKE_EMAIL = "oauth@example.com"
    FAKE_AVATAR = "https://lh3.googleusercontent.com/photo.jpg"

    def _patch_google(self, monkeypatch) -> None:
        """Set fake Google client credentials for the duration of one test."""
        monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "fake-client-id")
        monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "fake-client-secret")

    def _userinfo(
        self,
        email: str | None = None,
        email_verified: bool = True,
        provider_user_id: str | None = None,
    ) -> OAuthUserInfo:
        """Build the profile the mocked provider returns from ``get_userinfo``.

        ``email`` and ``provider_user_id`` default to the class-level fakes.
        A test that must not share an OAuth identity with another test
        passes its own values.
        """
        return OAuthUserInfo(
            provider_user_id=provider_user_id or self.FAKE_PROVIDER_USER_ID,
            email=email or self.FAKE_EMAIL,
            email_verified=email_verified,
            avatar_url=self.FAKE_AVATAR,
            name="OAuth User",
        )

    def _authorize(self, client) -> str:
        """Call /authorize and return the fresh state token."""
        resp = client.get("/api/v1/auth/oauth/google/authorize")
        assert resp.status_code == 200
        return resp.json()["state"]

    def _callback(self, client, state: str, userinfo: OAuthUserInfo):
        """POST /callback with mocked provider exchange_code + get_userinfo."""
        # Patching the class attributes means no request ever reaches Google.
        with (
            patch.object(
                GoogleOAuthProvider,
                "exchange_code",
                new=AsyncMock(return_value={"access_token": "google-access-tok"}),
            ),
            patch.object(
                GoogleOAuthProvider,
                "get_userinfo",
                new=AsyncMock(return_value=userinfo),
            ),
        ):
            return client.post(
                "/api/v1/auth/oauth/google/callback",
                json={"code": "auth-code", "state": state},
            )

    def _decode_sub(self, access_token: str) -> str:
        """Verify *access_token* with the app secret and return its ``sub`` claim."""
        return jwt.decode(
            access_token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )["sub"]

    # -- authorize --

    def test_authorize_returns_url_and_state(self, client, monkeypatch) -> None:
        self._patch_google(monkeypatch)
        resp = client.get("/api/v1/auth/oauth/google/authorize")
        assert resp.status_code == 200
        data = resp.json()
        assert "url" in data and "state" in data
        assert "accounts.google.com" in data["url"]
        assert len(data["state"]) > 0

    def test_authorize_unconfigured_returns_503(self, client, monkeypatch) -> None:
        monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "")
        monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "")
        resp = client.get("/api/v1/auth/oauth/google/authorize")
        assert resp.status_code == 503

    # -- callback --

    def test_callback_state_mismatch_returns_401(self, client, monkeypatch) -> None:
        self._patch_google(monkeypatch)
        resp = client.post(
            "/api/v1/auth/oauth/google/callback",
            json={"code": "code", "state": "not-a-real-state"},
        )
        assert resp.status_code == 401

    def test_callback_creates_new_user(self, client, monkeypatch) -> None:
        """First-time Google login creates a new user and returns valid tokens."""
        # This is the only test that uses the class-level default identity,
        # so the login is genuinely first-time regardless of test order.
        self._patch_google(monkeypatch)
        state = self._authorize(client)
        resp = self._callback(client, state, self._userinfo())
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data and "refresh_token" in data
        payload = jwt.decode(
            data["access_token"], settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
        assert payload["email"] == self.FAKE_EMAIL

    def test_callback_existing_oauth_link_logs_in(self, client, monkeypatch) -> None:
        """Second Google login with the same account re-uses the existing user."""
        self._patch_google(monkeypatch)
        # Dedicated identity, so this test never depends on (or pollutes)
        # test_callback_creates_new_user's account.
        userinfo = self._userinfo(
            email="oauth-relogin@example.com",
            provider_user_id="google-sub-relogin",
        )
        # First login — creates user + oauth_accounts row
        resp1 = self._callback(client, self._authorize(client), userinfo)
        assert resp1.status_code == 200
        sub1 = self._decode_sub(resp1.json()["access_token"])
        # Second login — finds existing oauth_accounts row → same user
        resp2 = self._callback(client, self._authorize(client), userinfo)
        assert resp2.status_code == 200
        sub2 = self._decode_sub(resp2.json()["access_token"])
        assert sub1 == sub2

    def test_callback_email_match_links_account(self, client, monkeypatch) -> None:
        """Verified Google email matching an existing password user links the accounts."""
        email = "link-target@example.com"
        reg_resp = client.post(
            "/api/v1/auth/register",
            json={"email": email, "password": "TestPass123!"},
        )
        assert reg_resp.status_code == 201
        orig_sub = self._decode_sub(reg_resp.json()["access_token"])
        self._patch_google(monkeypatch)
        state = self._authorize(client)
        # Use a Google sub that no other test links. If the shared default sub
        # were already linked by another test, that existing link (the path
        # test_callback_existing_oauth_link_logs_in covers) could resolve the
        # login before email matching is reached.
        userinfo = self._userinfo(
            email=email,
            email_verified=True,
            provider_user_id="google-sub-link-target",
        )
        resp = self._callback(client, state, userinfo)
        assert resp.status_code == 200
        oauth_sub = self._decode_sub(resp.json()["access_token"])
        # OAuth login must resolve to the same user as the original registration
        assert orig_sub == oauth_sub