From a98e99f7a250530dc6906fad11d8375b7cfa6c54 Mon Sep 17 00:00:00 2001 From: Roberto Date: Tue, 12 May 2026 08:23:22 +0200 Subject: [PATCH] feat(api): folder quota helpers with atomic token usage Implements check_folder_quota and add_token_usage in app/billing/quota.py with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite). Adds test_user_free/test_user_power fixtures and db alias to conftest.py. 6 new tests pass. Co-Authored-By: Claude Sonnet 4.6 --- app/billing/quota.py | 139 +++++++++++++++++++++++++++++++++++++ tests/conftest.py | 28 ++++++++ tests/test_folder_quota.py | 73 +++++++++++++++++++ 3 files changed, 240 insertions(+) create mode 100644 app/billing/quota.py create mode 100644 tests/test_folder_quota.py diff --git a/app/billing/quota.py b/app/billing/quota.py new file mode 100644 index 0000000..f22767c --- /dev/null +++ b/app/billing/quota.py @@ -0,0 +1,139 @@ +"""Quota checks and atomic token-usage accounting for folder integration.""" +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone + +from sqlalchemy import select, update +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.billing.tier_manager import TierManager +from app.models import MonthlyTokenUsage +from app.schemas import BillingTier + + +class QuotaExceeded(Exception): + """Raised when a folder operation cannot proceed under the user's tier.""" + + def __init__(self, reason: str, message: str) -> None: + super().__init__(message) + self.reason = reason # "max_files" | "monthly_tokens" + + +@dataclass +class TokenUsageResult: + tokens_used: int + exhausted: bool + + +def _current_year_month() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m") + + +_tier_manager = TierManager() + + +async def check_folder_quota( + *, + user_id: str, + tier: BillingTier, + estimated_files: int, + db: AsyncSession, +) -> None: + """Raise QuotaExceeded if folder_max_files or folder_monthly_tokens + would be violated. -1 in either feature means unlimited.""" + max_files = _tier_manager.get_feature_value(tier, "folder_max_files") + if max_files != -1 and estimated_files > max_files: + raise QuotaExceeded( + "max_files", + f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.", + ) + + cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens") + if cap == -1: + return + ym = _current_year_month() + row = ( + await db.execute( + select(MonthlyTokenUsage).where( + MonthlyTokenUsage.user_id == user_id, + MonthlyTokenUsage.year_month == ym, + MonthlyTokenUsage.feature == "folder_index", + ) + ) + ).scalar_one_or_none() + used = row.tokens_used if row else 0 + if used >= cap: + raise QuotaExceeded( + "monthly_tokens", + f"Monthly token budget exhausted ({used}/{cap}); resets next month.", + ) + + +async def add_token_usage( + *, + user_id: str, + feature: str, + tokens: int, + db: AsyncSession, + cap: int | None = None, +) -> TokenUsageResult: + """Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature). + + Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls + back to a read-then-write on other engines (e.g. aiosqlite in tests). + Returns post-update total and whether cap is exhausted. + """ + ym = _current_year_month() + + # Detect dialect to choose between native upsert and portable fallback. + dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr] + + if dialect_name == "postgresql": + # Native atomic upsert — production path. + stmt = ( + pg_insert(MonthlyTokenUsage) + .values( + user_id=user_id, + year_month=ym, + feature=feature, + tokens_used=tokens, + ) + .on_conflict_do_update( + index_elements=["user_id", "year_month", "feature"], + set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens}, + ) + .returning(MonthlyTokenUsage.tokens_used) + ) + used: int = (await db.execute(stmt)).scalar_one() + await db.commit() + else: + # Portable fallback — used in tests (SQLite) and any non-PG engine. + row = ( + await db.execute( + select(MonthlyTokenUsage).where( + MonthlyTokenUsage.user_id == user_id, + MonthlyTokenUsage.year_month == ym, + MonthlyTokenUsage.feature == feature, + ) + ) + ).scalar_one_or_none() + + if row is None: + row = MonthlyTokenUsage( + user_id=user_id, + year_month=ym, + feature=feature, + tokens_used=tokens, + ) + db.add(row) + else: + row.tokens_used += tokens + + await db.commit() + await db.refresh(row) + used = row.tokens_used + + exhausted = cap is not None and cap != -1 and used >= cap + return TokenUsageResult(tokens_used=used, exhausted=exhausted) diff --git a/tests/conftest.py b/tests/conftest.py index fdef3ad..b82b4f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,8 @@ from jose import jwt from sqlalchemy import StaticPool, event from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy import select + from app.config.settings import settings from app.db import Base, get_session from app.main import app @@ -134,6 +136,32 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} +# ── Convenience aliases and per-tier user fixtures ──────────────────── + +@pytest_asyncio.fixture +async def db(db_session: AsyncSession) -> AsyncSession: + """Alias for db_session — used by folder quota tests.""" + return db_session + + +@pytest_asyncio.fixture +async def test_user_free(db_session: AsyncSession): + """Return the seeded free-tier User row.""" + result = await db_session.execute( + select(User).where(User.id == TEST_USER_IDS["free"]) + ) + return result.scalar_one() + + +@pytest_asyncio.fixture +async def test_user_power(db_session: AsyncSession): + """Return the seeded power-tier User row.""" + result = await db_session.execute( + select(User).where(User.id == TEST_USER_IDS["power"]) + ) + return result.scalar_one() + + # ── CLI options ─────────────────────────────────────────────────────── def pytest_addoption(parser): diff --git a/tests/test_folder_quota.py b/tests/test_folder_quota.py new file mode 100644 index 0000000..1b61716 --- /dev/null +++ b/tests/test_folder_quota.py @@ -0,0 +1,73 @@ +"""Folder quota helpers.""" +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from sqlalchemy import select + +from app.billing.quota import ( + check_folder_quota, + add_token_usage, + QuotaExceeded, +) +from app.models import MonthlyTokenUsage + + +pytestmark = pytest.mark.asyncio + + +async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free): + with pytest.raises(QuotaExceeded) as exc: + await check_folder_quota( + user_id=test_user_free.id, tier="free", estimated_files=500, db=db + ) + assert exc.value.reason == "max_files" + + +async def test_check_folder_quota_free_passes_under_cap(db, test_user_free): + # No raise + await check_folder_quota( + user_id=test_user_free.id, tier="free", estimated_files=50, db=db + ) + + +async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free): + ym = datetime.now(timezone.utc).strftime("%Y-%m") + db.add(MonthlyTokenUsage( + user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000 + )) + await db.commit() + with pytest.raises(QuotaExceeded) as exc: + await check_folder_quota( + user_id=test_user_free.id, tier="free", estimated_files=10, db=db + ) + assert exc.value.reason == "monthly_tokens" + + +async def test_check_folder_quota_power_unlimited(db, test_user_power): + await check_folder_quota( + user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db + ) + + +async def test_add_token_usage_atomic_increment(db, test_user_free): + await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db) + await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db) + ym = datetime.now(timezone.utc).strftime("%Y-%m") + row = (await db.execute( + select(MonthlyTokenUsage).where( + MonthlyTokenUsage.user_id == test_user_free.id, + MonthlyTokenUsage.year_month == ym, + MonthlyTokenUsage.feature == "folder_index", + ) + )).scalar_one() + assert row.tokens_used == 4000 + + +async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free): + result = await add_token_usage( + user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000 + ) + assert result.exhausted is True + assert result.tokens_used == 150_000