"""Tests for WS folder index_session handlers (Task 9). Tests the three handler functions directly with a minimal fake WebSocket so no real WS connection or LLM call is made. """ from __future__ import annotations import json from datetime import datetime, timezone from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio from app.api.routes.device_ws import ( _handle_index_session_start, _handle_index_file_batch, _handle_index_session_cancel, _index_sessions, ) from app.billing.quota import add_token_usage from app.core.folder_indexer import IndexResult from app.models import MonthlyTokenUsage from app.schemas import WsFrameType from tests.conftest import TEST_USER_IDS pytestmark = pytest.mark.asyncio USER_ID = TEST_USER_IDS["free"] POWER_USER_ID = TEST_USER_IDS["power"] # ── Fake WebSocket ──────────────────────────────────────────────────── class _FakeWebSocket: """Minimal WebSocket stand-in that records send_text calls.""" def __init__(self) -> None: self.sent: list[dict] = [] async def send_text(self, text: str) -> None: self.sent.append(json.loads(text)) def sent_types(self) -> list[str]: return [f["type"] for f in self.sent] # ── Helpers ─────────────────────────────────────────────────────────── def _make_session_id() -> str: import uuid return str(uuid.uuid4()) def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100): """Return an AsyncMock that resolves to a fixed IndexResult.""" async def _fake(**kwargs) -> IndexResult: return IndexResult(summary=summary, tokens_used=tokens) return _fake # ── Fixtures ────────────────────────────────────────────────────────── @pytest_asyncio.fixture(autouse=True) async def _clean_sessions(): """Ensure _index_sessions is empty before and after each test.""" _index_sessions.clear() yield _index_sessions.clear() # ── Tests ───────────────────────────────────────────────────────────── async def test_index_session_happy_path(db_session): """start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed).""" ws = _FakeWebSocket() session_id = _make_session_id() # Register session. await _handle_index_session_start(ws, USER_ID, { "sessionId": session_id, "projectId": "proj-1", "totalFiles": 2, }) # Verify session was registered. assert session_id in _index_sessions assert _index_sessions[session_id]["total"] == 2 assert _index_sessions[session_id]["processed"] == 0 # No response frames expected for session_start. assert ws.sent == [] # Send batch of 2 text files — patch summarize_text so no LLM call needed. with patch( "app.api.routes.device_ws._handle_index_file_batch.__globals__", # We patch the module-level function in folder_indexer instead: ) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()): with patch("app.api.routes.device_ws.async_session") as mock_async_session: # Wire db_session into the context manager. mock_cm = AsyncMock() mock_cm.__aenter__ = AsyncMock(return_value=db_session) mock_cm.__aexit__ = AsyncMock(return_value=False) mock_async_session.return_value = mock_cm await _handle_index_file_batch(ws, USER_ID, { "sessionId": session_id, "files": [ {"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"}, {"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"}, ], }) types = ws.sent_types() # Expect 2 file results + 1 progress + 1 done(completed). assert types.count(WsFrameType.index_file_result) == 2 assert types.count(WsFrameType.index_session_progress) == 1 assert types.count(WsFrameType.index_session_done) == 1 done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done) assert done_frame["status"] == "completed" progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress) assert progress_frame["processed"] == 2 assert progress_frame["total"] == 2 # Verify session cleaned up. assert session_id not in _index_sessions async def test_index_session_cancel(db_session): """start then cancel → index_session_done(cancelled).""" ws = _FakeWebSocket() session_id = _make_session_id() await _handle_index_session_start(ws, USER_ID, { "sessionId": session_id, "totalFiles": 5, }) assert session_id in _index_sessions await _handle_index_session_cancel(ws, {"sessionId": session_id}) types = ws.sent_types() assert WsFrameType.index_session_done in types done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done) assert done_frame["status"] == "cancelled" # Session should be cleaned up. assert session_id not in _index_sessions async def test_index_session_quota_exceeded(db_session): """Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded).""" ws = _FakeWebSocket() session_id = _make_session_id() # Pre-fill monthly token usage to the free-tier cap (100_000). ym = datetime.now(timezone.utc).strftime("%Y-%m") db_session.add(MonthlyTokenUsage( user_id=USER_ID, year_month=ym, feature="folder_index", tokens_used=100_000, # free tier cap exactly )) await db_session.commit() await _handle_index_session_start(ws, USER_ID, { "sessionId": session_id, "totalFiles": 1, }) with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)): with patch("app.api.routes.device_ws.async_session") as mock_async_session: mock_cm = AsyncMock() mock_cm.__aenter__ = AsyncMock(return_value=db_session) mock_cm.__aexit__ = AsyncMock(return_value=False) mock_async_session.return_value = mock_cm await _handle_index_file_batch(ws, USER_ID, { "sessionId": session_id, "files": [ {"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"}, ], }) types = ws.sent_types() # Should have 1 file result (success) then done(quota_exceeded). assert WsFrameType.index_file_result in types assert WsFrameType.index_session_done in types done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done) assert done_frame["status"] == "quota_exceeded" # Session should be cleaned up. assert session_id not in _index_sessions