Implements check_folder_quota and add_token_usage in app/billing/quota.py with a dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite). Adds test_user_free/test_user_power fixtures and a db alias to conftest.py; 6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
140 lines
4.3 KiB
Python
140 lines
4.3 KiB
Python
"""Quota checks and atomic token-usage accounting for folder integration."""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.billing.tier_manager import TierManager
|
|
from app.models import MonthlyTokenUsage
|
|
from app.schemas import BillingTier
|
|
|
|
|
|
class QuotaExceeded(Exception):
    """Raised when a folder operation cannot proceed under the user's tier.

    Attributes:
        reason: machine-readable cause. This module raises it with
            ``"max_files"`` or ``"monthly_tokens"``.

    The human-readable message is the single exception argument, so
    ``str(exc)`` returns it.
    """

    def __init__(self, reason: str, message: str) -> None:
        super().__init__(message)
        self.reason = reason  # "max_files" | "monthly_tokens"

    def __reduce__(self) -> tuple[object, ...]:
        # BaseException's default reduce rebuilds the exception as
        # ``cls(*self.args)``, i.e. ``cls(message)``. That call fails because
        # ``__init__`` also requires ``reason``. Supply both constructor
        # arguments, plus the instance dict, so pickle/copy round-trip
        # (e.g. when the exception crosses a process boundary).
        return (type(self), (self.reason, *self.args), self.__dict__)
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class TokenUsageResult:
    """Outcome returned by ``add_token_usage``."""

    # Running total for (user, month, feature) after this call's tokens were added.
    tokens_used: int
    # True when a cap other than None/-1 was supplied and the total has reached it.
    exhausted: bool
|
|
|
|
|
|
def _current_year_month() -> str:
    """Return the current UTC calendar month formatted as ``"YYYY-MM"``.

    This string is the bucket key written to / matched against
    ``MonthlyTokenUsage.year_month``.
    """
    now_utc = datetime.now(tz=timezone.utc)
    return f"{now_utc.year:04d}-{now_utc.month:02d}"
|
|
|
|
|
|
# Single TierManager instance, created once at import time and used by the
# quota checks below to look up per-tier feature limits.
_tier_manager: TierManager = TierManager()
|
|
|
|
|
|
async def check_folder_quota(
    *,
    user_id: str,
    tier: BillingTier,
    estimated_files: int,
    db: AsyncSession,
    feature: str = "folder_index",
) -> None:
    """Raise ``QuotaExceeded`` if a folder operation would break the tier's limits.

    Two tier features are consulted; ``-1`` in either means unlimited:

    * ``folder_max_files``: ``estimated_files`` must not exceed it
      (raises with ``reason="max_files"``).
    * ``folder_monthly_tokens``: the tokens already recorded in
      ``MonthlyTokenUsage`` for ``(user_id, current UTC month, feature)``
      must be below it (raises with ``reason="monthly_tokens"``).

    Args:
        user_id: owner of the usage counter.
        tier: billing tier whose feature limits apply.
        estimated_files: number of files the folder operation will touch.
        db: async session; this function only issues a SELECT.
        feature: usage bucket to check. Defaults to ``"folder_index"`` (the
            value this function previously hard-coded). It should match the
            ``feature`` passed to ``add_token_usage``.

    Raises:
        QuotaExceeded: when either limit is violated.
    """
    max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
    if max_files != -1 and estimated_files > max_files:
        raise QuotaExceeded(
            "max_files",
            f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
        )

    cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
    if cap == -1:
        return

    # Only the counter is needed, so select the column instead of loading the
    # ORM entity; this also always reads the value from the database rather
    # than an object that may already sit in the session's identity map.
    used = (
        await db.execute(
            select(MonthlyTokenUsage.tokens_used).where(
                MonthlyTokenUsage.user_id == user_id,
                MonthlyTokenUsage.year_month == _current_year_month(),
                MonthlyTokenUsage.feature == feature,
            )
        )
    ).scalar_one_or_none()
    if used is None:
        used = 0  # no usage recorded for this bucket yet this month

    if used >= cap:
        raise QuotaExceeded(
            "monthly_tokens",
            f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
        )
|
|
|
|
|
|
async def _add_usage_postgres(
    db: AsyncSession,
    *,
    user_id: str,
    year_month: str,
    feature: str,
    tokens: int,
) -> int:
    """Create-or-increment the usage row with one PostgreSQL upsert; return the new total."""
    stmt = (
        pg_insert(MonthlyTokenUsage)
        .values(
            user_id=user_id,
            year_month=year_month,
            feature=feature,
            tokens_used=tokens,
        )
        .on_conflict_do_update(
            # ON CONFLICT needs a unique index/constraint over exactly these columns.
            index_elements=["user_id", "year_month", "feature"],
            # In DO UPDATE, the table column refers to the existing row's value.
            set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
        )
        .returning(MonthlyTokenUsage.tokens_used)
    )
    used: int = (await db.execute(stmt)).scalar_one()
    await db.commit()
    return used


async def _add_usage_portable(
    db: AsyncSession,
    *,
    user_id: str,
    year_month: str,
    feature: str,
    tokens: int,
) -> int:
    """Increment (or create) the usage row without dialect-specific SQL; return the new total."""
    key = (
        MonthlyTokenUsage.user_id == user_id,
        MonthlyTokenUsage.year_month == year_month,
        MonthlyTokenUsage.feature == feature,
    )
    # The database evaluates ``tokens_used + :tokens`` itself, so two sessions
    # incrementing the same existing row cannot overwrite each other's update,
    # unlike a read-modify-write done in Python.
    result = await db.execute(
        update(MonthlyTokenUsage)
        .where(*key)
        .values(tokens_used=MonthlyTokenUsage.tokens_used + tokens)
    )
    if result.rowcount == 0:  # type: ignore[attr-defined]
        # No row yet for this key this month.
        # NOTE(review): two callers racing to create the first row can still
        # collide here. The loser's commit presumably fails on the
        # (user_id, year_month, feature) unique key.
        db.add(
            MonthlyTokenUsage(
                user_id=user_id,
                year_month=year_month,
                feature=feature,
                tokens_used=tokens,
            )
        )
    await db.commit()
    # Read the committed total back from the database (counterpart of RETURNING on PG).
    return (
        await db.execute(select(MonthlyTokenUsage.tokens_used).where(*key))
    ).scalar_one()


async def add_token_usage(
    *,
    user_id: str,
    feature: str,
    tokens: int,
    db: AsyncSession,
    cap: int | None = None,
) -> TokenUsageResult:
    """Add ``tokens`` to the MonthlyTokenUsage row for (user, current UTC month, feature).

    On PostgreSQL a single ``INSERT … ON CONFLICT DO UPDATE … RETURNING``
    creates or increments the row atomically. Other engines (e.g. aiosqlite
    in tests) run ``UPDATE … SET tokens_used = tokens_used + :n`` and insert
    the row only when nothing matched. The increment is computed inside the
    database, so concurrent additions to an existing row are not lost.
    Both paths commit the session.

    Args:
        user_id: owner of the usage counter.
        feature: usage bucket (e.g. ``"folder_index"``).
        tokens: amount added to the running total.
        db: async session; committed before returning.
        cap: optional budget; ``None`` or ``-1`` means no cap.

    Returns:
        TokenUsageResult with the post-update total and whether it has
        reached ``cap``.
    """
    ym = _current_year_month()

    # Pick native upsert vs. portable path by dialect. A session without a
    # direct bind falls back to the portable path.
    dialect_name: str = db.bind.dialect.name if db.bind is not None else ""  # type: ignore[union-attr]
    add_usage = _add_usage_postgres if dialect_name == "postgresql" else _add_usage_portable

    used = await add_usage(
        db,
        user_id=user_id,
        year_month=ym,
        feature=feature,
        tokens=tokens,
    )

    exhausted = cap is not None and cap != -1 and used >= cap
    return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|