feat(api): folder quota helpers with atomic token usage

Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Roberto
2026-05-12 08:23:22 +02:00
parent a0ff285bcd
commit a98e99f7a2
3 changed files with 240 additions and 0 deletions

139
app/billing/quota.py Normal file
View File

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
"""Raised when a folder operation cannot proceed under the user's tier."""
def __init__(self, reason: str, message: str) -> None:
super().__init__(message)
self.reason = reason # "max_files" | "monthly_tokens"
@dataclass
class TokenUsageResult:
tokens_used: int
exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
*,
user_id: str,
tier: BillingTier,
estimated_files: int,
db: AsyncSession,
) -> None:
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
would be violated. -1 in either feature means unlimited."""
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
if max_files != -1 and estimated_files > max_files:
raise QuotaExceeded(
"max_files",
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
)
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
if cap == -1:
return
ym = _current_year_month()
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)
).scalar_one_or_none()
used = row.tokens_used if row else 0
if used >= cap:
raise QuotaExceeded(
"monthly_tokens",
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
)
async def add_token_usage(
*,
user_id: str,
feature: str,
tokens: int,
db: AsyncSession,
cap: int | None = None,
) -> TokenUsageResult:
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
back to a read-then-write on other engines (e.g. aiosqlite in tests).
Returns post-update total and whether cap is exhausted.
"""
ym = _current_year_month()
# Detect dialect to choose between native upsert and portable fallback.
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
if dialect_name == "postgresql":
# Native atomic upsert — production path.
stmt = (
pg_insert(MonthlyTokenUsage)
.values(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
.on_conflict_do_update(
index_elements=["user_id", "year_month", "feature"],
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
)
.returning(MonthlyTokenUsage.tokens_used)
)
used: int = (await db.execute(stmt)).scalar_one()
await db.commit()
else:
# Portable fallback — used in tests (SQLite) and any non-PG engine.
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == feature,
)
)
).scalar_one_or_none()
if row is None:
row = MonthlyTokenUsage(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
db.add(row)
else:
row.tokens_used += tokens
await db.commit()
await db.refresh(row)
used = row.tokens_used
exhausted = cap is not None and cap != -1 and used >= cap
return TokenUsageResult(tokens_used=used, exhausted=exhausted)

View File

@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select
from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
@@ -134,6 +136,32 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── Convenience aliases and per-tier user fixtures ────────────────────
@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
"""Alias for db_session — used by folder quota tests."""
return db_session
@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
"""Return the seeded free-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["free"])
)
return result.scalar_one()
@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
"""Return the seeded power-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["power"])
)
return result.scalar_one()
# ── CLI options ───────────────────────────────────────────────────────
def pytest_addoption(parser):

View File

@@ -0,0 +1,73 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
)
assert exc.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
# No raise
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
)
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db.add(MonthlyTokenUsage(
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
))
await db.commit()
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
)
assert exc.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
await check_folder_quota(
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
)
async def test_add_token_usage_atomic_increment(db, test_user_free):
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
ym = datetime.now(timezone.utc).strftime("%Y-%m")
row = (await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == test_user_free.id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)).scalar_one()
assert row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
result = await add_token_usage(
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
)
assert result.exhausted is True
assert result.tokens_used == 150_000