Add six v7 WsFrameType enum members (index_session_start/cancel/batch, index_file_result/progress/done), wire dispatch in device_ws message loop, and implement _handle_index_session_start/cancel/file_batch with per-file summarisation, token accounting, and quota enforcement. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
197 lines
7.2 KiB
Python
197 lines
7.2 KiB
Python
"""Tests for WS folder index_session handlers (Task 9).
|
|
|
|
Tests the three handler functions directly with a minimal fake WebSocket so
|
|
no real WS connection or LLM call is made.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from app.api.routes.device_ws import (
|
|
_handle_index_session_start,
|
|
_handle_index_file_batch,
|
|
_handle_index_session_cancel,
|
|
_index_sessions,
|
|
)
|
|
from app.billing.quota import add_token_usage
|
|
from app.core.folder_indexer import IndexResult
|
|
from app.models import MonthlyTokenUsage
|
|
from app.schemas import WsFrameType
|
|
from tests.conftest import TEST_USER_IDS
|
|
|
|
pytestmark = pytest.mark.asyncio
|
|
|
|
USER_ID = TEST_USER_IDS["free"]
|
|
POWER_USER_ID = TEST_USER_IDS["power"]
|
|
|
|
|
|
# ── Fake WebSocket ────────────────────────────────────────────────────
|
|
|
|
class _FakeWebSocket:
    """In-memory WebSocket double used by the handler tests.

    Only ``send_text`` is provided. Every outgoing payload is decoded from
    JSON and stored in ``sent`` in the order it was sent, so a test can
    inspect the frames without opening a real connection.
    """

    def __init__(self) -> None:
        # Decoded frames, oldest first.
        self.sent: list[dict] = []

    async def send_text(self, text: str) -> None:
        """Decode *text* as JSON and record the resulting frame."""
        frame = json.loads(text)
        self.sent.append(frame)

    def sent_types(self) -> list[str]:
        """Return the ``"type"`` value of each recorded frame, in send order."""
        return [frame["type"] for frame in self.sent]
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────
|
|
|
|
def _make_session_id() -> str:
    """Return a freshly generated random UUID (version 4) as a string."""
    from uuid import uuid4

    return f"{uuid4()}"
|
|
|
|
|
|
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
    """Build an async stand-in for ``summarize_text``.

    The returned coroutine function accepts arbitrary keyword arguments,
    ignores them, and resolves to a new
    ``IndexResult(summary=summary, tokens_used=tokens)`` on every call.
    It is meant to be handed to ``patch(..., side_effect=...)``.
    """
    async def _stub(**_ignored) -> IndexResult:
        canned = IndexResult(summary=summary, tokens_used=tokens)
        return canned

    return _stub
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture(autouse=True)
async def _clean_sessions():
    """Reset the shared ``_index_sessions`` mapping around every test."""
    # Setup: discard anything left behind by an earlier test.
    _index_sessions.clear()
    yield
    # Teardown: discard whatever the finished test registered.
    _index_sessions.clear()
|
|
|
|
|
|
# ── Tests ─────────────────────────────────────────────────────────────
|
|
|
|
async def test_index_session_happy_path(db_session):
    """start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Register session.
    await _handle_index_session_start(ws, USER_ID, {
        "sessionId": session_id,
        "projectId": "proj-1",
        "totalFiles": 2,
    })

    # Verify session was registered.
    assert session_id in _index_sessions
    assert _index_sessions[session_id]["total"] == 2
    assert _index_sessions[session_id]["processed"] == 0
    # No response frames expected for session_start.
    assert ws.sent == []

    # Send a batch of 2 text files. summarize_text is patched so no LLM call
    # is needed, and async_session is patched so the handler's
    # ``async with async_session()`` yields the test's db_session.
    with patch(
        "app.core.folder_indexer.summarize_text",
        side_effect=_fake_summarize_text_factory(),
    ), patch("app.api.routes.device_ws.async_session") as mock_async_session:
        # Wire db_session into the context manager.
        mock_cm = AsyncMock()
        mock_cm.__aenter__ = AsyncMock(return_value=db_session)
        # A falsy __aexit__ result lets exceptions raised inside the handler propagate.
        mock_cm.__aexit__ = AsyncMock(return_value=False)
        mock_async_session.return_value = mock_cm

        await _handle_index_file_batch(ws, USER_ID, {
            "sessionId": session_id,
            "files": [
                {"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
                {"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
            ],
        })

    types = ws.sent_types()
    # Expect 2 file results + 1 progress + 1 done(completed).
    assert types.count(WsFrameType.index_file_result) == 2
    assert types.count(WsFrameType.index_session_progress) == 1
    assert types.count(WsFrameType.index_session_done) == 1

    done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
    assert done_frame["status"] == "completed"

    progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
    assert progress_frame["processed"] == 2
    assert progress_frame["total"] == 2

    # Verify session cleaned up.
    assert session_id not in _index_sessions
|
|
|
|
|
|
async def test_index_session_cancel(db_session):
    """start then cancel → index_session_done(cancelled)."""
    session_id = _make_session_id()
    ws = _FakeWebSocket()

    start_payload = {
        "sessionId": session_id,
        "totalFiles": 5,
    }
    await _handle_index_session_start(ws, USER_ID, start_payload)
    # The session must exist before it can be cancelled.
    assert session_id in _index_sessions

    cancel_payload = {"sessionId": session_id}
    await _handle_index_session_cancel(ws, cancel_payload)

    # A done frame must have been sent, and the first one reports "cancelled".
    assert WsFrameType.index_session_done in ws.sent_types()
    done_frames = [
        frame for frame in ws.sent
        if frame["type"] == WsFrameType.index_session_done
    ]
    assert done_frames[0]["status"] == "cancelled"

    # Cancelling removes the session from the registry.
    assert session_id not in _index_sessions
|
|
|
|
|
|
async def test_index_session_quota_exceeded(db_session):
    """Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
    ws = _FakeWebSocket()
    session_id = _make_session_id()

    # Seed the current month's folder_index usage with 100_000 tokens, which
    # this test treats as exactly the free-tier cap.
    current_month = datetime.now(timezone.utc).strftime("%Y-%m")
    usage_row = MonthlyTokenUsage(
        user_id=USER_ID,
        year_month=current_month,
        feature="folder_index",
        tokens_used=100_000,
    )
    db_session.add(usage_row)
    await db_session.commit()

    start_payload = {
        "sessionId": session_id,
        "totalFiles": 1,
    }
    await _handle_index_session_start(ws, USER_ID, start_payload)

    batch_payload = {
        "sessionId": session_id,
        "files": [
            {"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
        ],
    }

    # Stub the summariser (1 token per file) and make the handler's
    # async_session() context manager yield the test's db_session.
    with patch(
        "app.core.folder_indexer.summarize_text",
        side_effect=_fake_summarize_text_factory(tokens=1),
    ), patch("app.api.routes.device_ws.async_session") as mock_async_session:
        session_cm = AsyncMock()
        session_cm.__aenter__ = AsyncMock(return_value=db_session)
        session_cm.__aexit__ = AsyncMock(return_value=False)
        mock_async_session.return_value = session_cm

        await _handle_index_file_batch(ws, USER_ID, batch_payload)

    # Should have 1 file result (success) then done(quota_exceeded).
    frame_types = ws.sent_types()
    assert WsFrameType.index_file_result in frame_types
    assert WsFrameType.index_session_done in frame_types

    done_frames = [
        frame for frame in ws.sent
        if frame["type"] == WsFrameType.index_session_done
    ]
    assert done_frames[0]["status"] == "quota_exceeded"

    # Session should be cleaned up.
    assert session_id not in _index_sessions
|