Files
api/tests/test_ws_index_session.py
Roberto 582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch,
index_file_result/progress/done), wire dispatch in device_ws message loop,
and implement _handle_index_session_start/cancel/file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00

197 lines
7.2 KiB
Python

"""Tests for WS folder index_session handlers (Task 9).
Tests the three handler functions directly with a minimal fake WebSocket so
no real WS connection or LLM call is made.
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
import pytest_asyncio
from app.api.routes.device_ws import (
_handle_index_session_start,
_handle_index_file_batch,
_handle_index_session_cancel,
_index_sessions,
)
from app.billing.quota import add_token_usage
from app.core.folder_indexer import IndexResult
from app.models import MonthlyTokenUsage
from app.schemas import WsFrameType
from tests.conftest import TEST_USER_IDS
# Run every test coroutine in this module under pytest-asyncio.
pytestmark = pytest.mark.asyncio

# Seeded user ids from conftest, selected by plan tier ("free" / "power").
USER_ID, POWER_USER_ID = TEST_USER_IDS["free"], TEST_USER_IDS["power"]
# ── Fake WebSocket ────────────────────────────────────────────────────
class _FakeWebSocket:
"""Minimal WebSocket stand-in that records send_text calls."""
def __init__(self) -> None:
self.sent: list[dict] = []
async def send_text(self, text: str) -> None:
self.sent.append(json.loads(text))
def sent_types(self) -> list[str]:
return [f["type"] for f in self.sent]
# ── Helpers ───────────────────────────────────────────────────────────
def _make_session_id() -> str:
import uuid
return str(uuid.uuid4())
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
"""Return an AsyncMock that resolves to a fixed IndexResult."""
async def _fake(**kwargs) -> IndexResult:
return IndexResult(summary=summary, tokens_used=tokens)
return _fake
# ── Fixtures ──────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
    """Give every test an empty ``_index_sessions`` registry.

    The registry is wiped before the test runs and wiped again afterwards,
    so a session left behind by one test cannot leak into the next.
    """
    reset_registry = _index_sessions.clear
    reset_registry()
    yield
    reset_registry()
# ── Tests ─────────────────────────────────────────────────────────────
async def test_index_session_happy_path(db_session):
    """start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Register session.
    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "projectId": "proj-1",
        "totalFiles": 2,
    })

    # Verify session was registered with nothing processed yet.
    assert session_id in _index_sessions
    assert _index_sessions[session_id]["total"] == 2
    assert _index_sessions[session_id]["processed"] == 0
    # No response frames expected for session_start.
    assert ws.sent == []

    # Send a batch of 2 text files. summarize_text is patched in
    # app.core.folder_indexer so no LLM call is made, and the handler's
    # async_session factory is patched so its context manager yields the
    # test's db_session.
    with patch(
        "app.core.folder_indexer.summarize_text",
        side_effect=_fake_summarize_text_factory(),
    ):
        with patch("app.api.routes.device_ws.async_session") as mock_async_session:
            mock_cm = AsyncMock()
            mock_cm.__aenter__ = AsyncMock(return_value=db_session)
            # Returning False from __aexit__ lets exceptions propagate.
            mock_cm.__aexit__ = AsyncMock(return_value=False)
            mock_async_session.return_value = mock_cm

            await _handle_index_file_batch(ws, USER_ID, {
                "sessionId": session_id,
                "files": [
                    {"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
                    {"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
                ],
            })

    types = ws.sent_types()
    # Expect 2 file results + 1 progress + 1 done(completed).
    assert types.count(WsFrameType.index_file_result) == 2
    assert types.count(WsFrameType.index_session_progress) == 1
    assert types.count(WsFrameType.index_session_done) == 1

    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "completed"

    progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
    assert progress_frame["processed"] == 2
    assert progress_frame["total"] == 2

    # Verify session cleaned up.
    assert session_id not in _index_sessions
async def test_index_session_cancel(db_session):
    """Cancelling a started session emits index_session_done with status "cancelled".

    The cancelled session must also disappear from the registry.
    """
    sid = _make_session_id()
    fake_ws = _FakeWebSocket()

    await _handle_index_session_start(fake_ws, USER_ID, {"sessionId": sid, "totalFiles": 5})
    assert sid in _index_sessions

    await _handle_index_session_cancel(fake_ws, {"sessionId": sid})

    assert WsFrameType.index_session_done in fake_ws.sent_types()
    done = next(
        frame for frame in fake_ws.sent
        if frame["type"] == WsFrameType.index_session_done
    )
    assert done["status"] == "cancelled"

    # The registry entry must be gone once the session is cancelled.
    assert sid not in _index_sessions
async def test_index_session_quota_exceeded(db_session):
    """Pre-fill usage to cap → batch one file → 1 index_file_result, then index_session_done(quota_exceeded)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Pre-fill this month's folder_index token usage to the free-tier cap (100_000).
    ym = datetime.now(timezone.utc).strftime("%Y-%m")
    db_session.add(MonthlyTokenUsage(
        user_id=USER_ID,
        year_month=ym,
        feature="folder_index",
        tokens_used=100_000,  # free tier cap exactly
    ))
    await db_session.commit()

    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "totalFiles": 1,
    })

    # Patch the summariser (1 token per file) and route the handler's DB
    # session factory to the test's db_session.
    with patch(
        "app.core.folder_indexer.summarize_text",
        side_effect=_fake_summarize_text_factory(tokens=1),
    ):
        with patch("app.api.routes.device_ws.async_session") as mock_async_session:
            mock_cm = AsyncMock()
            mock_cm.__aenter__ = AsyncMock(return_value=db_session)
            mock_cm.__aexit__ = AsyncMock(return_value=False)
            mock_async_session.return_value = mock_cm

            await _handle_index_file_batch(ws, USER_ID, {
                "sessionId": session_id,
                "files": [
                    {"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
                ],
            })

    types = ws.sent_types()
    # Should have exactly 1 file result, followed by done(quota_exceeded).
    assert types.count(WsFrameType.index_file_result) == 1
    assert WsFrameType.index_session_done in types
    assert (
        types.index(WsFrameType.index_file_result)
        < types.index(WsFrameType.index_session_done)
    )

    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "quota_exceeded"

    # Session should be cleaned up.
    assert session_id not in _index_sessions