feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch, index_file_result/progress/done), wire dispatch in device_ws message loop, and implement _handle_index_session_start/cancel/file_batch with per-file summarisation, token accounting, and quota enforcement. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Tests for WS folder index_session handlers (Task 9).
|
||||
|
||||
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||
no real WS connection or LLM call is made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.api.routes.device_ws import (
|
||||
_handle_index_session_start,
|
||||
_handle_index_file_batch,
|
||||
_handle_index_session_cancel,
|
||||
_index_sessions,
|
||||
)
|
||||
from app.billing.quota import add_token_usage
|
||||
from app.core.folder_indexer import IndexResult
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import WsFrameType
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
USER_ID = TEST_USER_IDS["free"]
|
||||
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||
|
||||
|
||||
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||
|
||||
class _FakeWebSocket:
|
||||
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent: list[dict] = []
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
self.sent.append(json.loads(text))
|
||||
|
||||
def sent_types(self) -> list[str]:
|
||||
return [f["type"] for f in self.sent]
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _make_session_id() -> str:
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||
async def _fake(**kwargs) -> IndexResult:
|
||||
return IndexResult(summary=summary, tokens_used=tokens)
|
||||
return _fake
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def _clean_sessions():
|
||||
"""Ensure _index_sessions is empty before and after each test."""
|
||||
_index_sessions.clear()
|
||||
yield
|
||||
_index_sessions.clear()
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def test_index_session_happy_path(db_session):
|
||||
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Register session.
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"projectId": "proj-1",
|
||||
"totalFiles": 2,
|
||||
})
|
||||
|
||||
# Verify session was registered.
|
||||
assert session_id in _index_sessions
|
||||
assert _index_sessions[session_id]["total"] == 2
|
||||
assert _index_sessions[session_id]["processed"] == 0
|
||||
# No response frames expected for session_start.
|
||||
assert ws.sent == []
|
||||
|
||||
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||
with patch(
|
||||
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||
# We patch the module-level function in folder_indexer instead:
|
||||
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
# Wire db_session into the context manager.
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||
assert types.count(WsFrameType.index_file_result) == 2
|
||||
assert types.count(WsFrameType.index_session_progress) == 1
|
||||
assert types.count(WsFrameType.index_session_done) == 1
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "completed"
|
||||
|
||||
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||
assert progress_frame["processed"] == 2
|
||||
assert progress_frame["total"] == 2
|
||||
|
||||
# Verify session cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_cancel(db_session):
|
||||
"""start then cancel → index_session_done(cancelled)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 5,
|
||||
})
|
||||
assert session_id in _index_sessions
|
||||
|
||||
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||
|
||||
types = ws.sent_types()
|
||||
assert WsFrameType.index_session_done in types
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "cancelled"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_quota_exceeded(db_session):
|
||||
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
db_session.add(MonthlyTokenUsage(
|
||||
user_id=USER_ID,
|
||||
year_month=ym,
|
||||
feature="folder_index",
|
||||
tokens_used=100_000, # free tier cap exactly
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 1,
|
||||
})
|
||||
|
||||
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Should have 1 file result (success) then done(quota_exceeded).
|
||||
assert WsFrameType.index_file_result in types
|
||||
assert WsFrameType.index_session_done in types
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "quota_exceeded"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
Reference in New Issue
Block a user