"""Quota checks and atomic token-usage accounting for folder integration."""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Literal

from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import UnboundExecutionError
from sqlalchemy.ext.asyncio import AsyncSession

from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier

# Feature key under which folder-indexing token usage is stored in
# MonthlyTokenUsage; check_folder_quota reads this bucket.
_FOLDER_INDEX_FEATURE = "folder_index"

# Tier feature values use -1 to mean "no limit".
_UNLIMITED = -1

QuotaReason = Literal["max_files", "monthly_tokens"]


class QuotaExceeded(Exception):
    """Raised when a folder operation cannot proceed under the user's tier.

    Attributes:
        reason: Machine-readable cause: ``"max_files"`` or ``"monthly_tokens"``.
    """

    def __init__(self, reason: QuotaReason, message: str) -> None:
        super().__init__(message)
        self.reason = reason


@dataclass
class TokenUsageResult:
    """Outcome of :func:`add_token_usage`."""

    # Total for (user, current UTC month, feature) after the increment.
    tokens_used: int
    # True only when a finite cap was supplied and tokens_used has reached it.
    exhausted: bool


def _current_year_month() -> str:
    """Return the current UTC month as ``"YYYY-MM"``, the usage-bucket key."""
    return datetime.now(timezone.utc).strftime("%Y-%m")


_tier_manager = TierManager()


def _dialect_name(db: AsyncSession) -> str:
    """Return the dialect name of the engine serving MonthlyTokenUsage, or "" if unbound.

    Resolves through ``get_bind`` instead of reading ``db.bind`` directly:
    ``db.bind`` is None for sessions configured with a per-mapper ``binds=``
    mapping, which would otherwise silently route PostgreSQL traffic to the
    non-atomic fallback path.
    """
    try:
        bind = db.get_bind(mapper=MonthlyTokenUsage)
    except UnboundExecutionError:
        return ""
    return bind.dialect.name


async def _fetch_usage_row(
    db: AsyncSession,
    *,
    user_id: str,
    year_month: str,
    feature: str,
) -> MonthlyTokenUsage | None:
    """Load the usage row for (user_id, year_month, feature), or None if absent."""
    result = await db.execute(
        select(MonthlyTokenUsage).where(
            MonthlyTokenUsage.user_id == user_id,
            MonthlyTokenUsage.year_month == year_month,
            MonthlyTokenUsage.feature == feature,
        )
    )
    return result.scalar_one_or_none()


async def check_folder_quota(
    *,
    user_id: str,
    tier: BillingTier,
    estimated_files: int,
    db: AsyncSession,
) -> None:
    """Raise QuotaExceeded if folder_max_files or folder_monthly_tokens would be violated.

    -1 in either feature means unlimited. This function only reads usage; it
    does not record any tokens.

    Args:
        user_id: Owner of the usage bucket.
        tier: Billing tier whose feature limits apply.
        estimated_files: Number of files the folder operation would touch.
        db: Session used to read the current month's ``folder_index`` usage.

    Raises:
        QuotaExceeded: ``reason="max_files"`` when the file count exceeds the
            tier limit; ``reason="monthly_tokens"`` when this month's usage has
            already reached the tier's token cap.
    """
    max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
    if max_files != _UNLIMITED and estimated_files > max_files:
        raise QuotaExceeded(
            "max_files",
            f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
        )

    cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
    if cap == _UNLIMITED:
        return

    row = await _fetch_usage_row(
        db,
        user_id=user_id,
        year_month=_current_year_month(),
        feature=_FOLDER_INDEX_FEATURE,
    )
    used = row.tokens_used if row else 0
    if used >= cap:
        raise QuotaExceeded(
            "monthly_tokens",
            f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
        )


async def add_token_usage(
    *,
    user_id: str,
    feature: str,
    tokens: int,
    db: AsyncSession,
    cap: int | None = None,
) -> TokenUsageResult:
    """Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).

    Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls back
    to a read-then-write on other engines (e.g. aiosqlite in tests). The fallback
    is not safe under concurrent writers.

    Side effect: commits ``db`` on both paths, which also commits any other
    pending work in the session.

    Args:
        user_id: Owner of the usage bucket.
        feature: Feature key of the bucket (e.g. ``"folder_index"``).
        tokens: Amount to add to the running total.
        db: Session to write through; it is committed before returning.
        cap: Optional limit; None or -1 means no limit.

    Returns:
        TokenUsageResult with the post-update total and whether ``cap`` is reached.
    """
    ym = _current_year_month()

    if _dialect_name(db) == "postgresql":
        # Native atomic upsert — production path.
        # NOTE(review): ON CONFLICT requires a unique constraint/index on
        # (user_id, year_month, feature) — confirm it exists in the schema.
        stmt = (
            pg_insert(MonthlyTokenUsage)
            .values(
                user_id=user_id,
                year_month=ym,
                feature=feature,
                tokens_used=tokens,
            )
            .on_conflict_do_update(
                index_elements=["user_id", "year_month", "feature"],
                # Column reference here is the existing row's value.
                set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
            )
            .returning(MonthlyTokenUsage.tokens_used)
        )
        used: int = (await db.execute(stmt)).scalar_one()
        await db.commit()
    else:
        # Portable fallback — used in tests (SQLite) and any non-PG engine.
        row = await _fetch_usage_row(
            db, user_id=user_id, year_month=ym, feature=feature
        )
        if row is None:
            row = MonthlyTokenUsage(
                user_id=user_id,
                year_month=ym,
                feature=feature,
                tokens_used=tokens,
            )
            db.add(row)
        else:
            row.tokens_used += tokens
        await db.commit()
        # Reload after commit so the returned total reflects what was persisted.
        await db.refresh(row)
        used = row.tokens_used

    exhausted = cap is not None and cap != _UNLIMITED and used >= cap
    return TokenUsageResult(tokens_used=used, exhausted=exhausted)