From 582bf27deb13ef018418e527bf050b5649fab619 Mon Sep 17 00:00:00 2001 From: Roberto Date: Tue, 12 May 2026 11:22:20 +0200 Subject: [PATCH] feat(api): WS index_session frames + handlers Add six v7 WsFrameType enum members (index_session_start/cancel/progress/done, index_file_batch/result), wire dispatch in device_ws message loop, and implement _handle_index_session_start, _handle_index_session_cancel, and _handle_index_file_batch with per-file summarisation, token accounting, and quota enforcement. Co-Authored-By: Claude Sonnet 4.6 --- app/api/routes/device_ws.py | 184 +++++++++++++++++++++++++++++++ app/schemas.py | 7 ++ tests/test_ws_index_session.py | 196 +++++++++++++++++++++++++++++++++ 3 files changed, 387 insertions(+) create mode 100644 tests/test_ws_index_session.py diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index 91de0f4..878de4a 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -57,6 +57,10 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/ws", tags=["device-ws"]) +# ── v7 folder index session state ───────────────────────────────────── +# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled } +_index_sessions: dict[str, dict] = {} + _HEARTBEAT_INTERVAL = 30 # seconds _PONG_TIMEOUT = 10 # seconds — grace window after a ping @@ -180,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None: _handle_journey_message(websocket, user_id, frame) ) + elif frame_type == WsFrameType.index_session_start: + asyncio.create_task( + _handle_index_session_start(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.index_file_batch: + asyncio.create_task( + _handle_index_file_batch(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.index_session_cancel: + await _handle_index_session_cancel(websocket, frame) + elif frame_type == "pong": # Heartbeat ack — nothing to do, connection is alive. 
pass @@ -569,6 +586,173 @@ async def _handle_journey_message( clear_client_executor() +# ── v7 Folder Index Handlers ────────────────────────────────────────── + + +async def _handle_index_session_start( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Register a new folder index session. No response sent — client is declaring intent.""" + session_id: str = frame.get("sessionId") or frame.get("session_id", "") + project_id: str | None = frame.get("projectId") or frame.get("project_id") + total: int = int(frame.get("totalFiles", 0)) + + if not session_id: + logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id) + return + + _index_sessions[session_id] = { + "user_id": user_id, + "project_id": project_id, + "processed": 0, + "total": total, + "cancelled": False, + } + logger.info( + "device_ws: index_session_start user=%s session=%s project=%s total=%d", + user_id, session_id, project_id, total, + ) + + +async def _handle_index_session_cancel( + websocket: WebSocket, + frame: dict, +) -> None: + """Mark a session as cancelled and emit index_session_done(cancelled).""" + session_id: str = frame.get("sessionId") or frame.get("session_id", "") + session = _index_sessions.get(session_id) + if session: + session["cancelled"] = True + + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_done, + "sessionId": session_id, + "status": "cancelled", + })) + _index_sessions.pop(session_id, None) + logger.info("device_ws: index_session_cancel session=%s", session_id) + + +async def _handle_index_file_batch( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Process a batch of files for an index session, streaming results back.""" + # Lazy imports to avoid heavy load at module startup. 
+ from app.core.folder_indexer import ( # noqa: PLC0415 + summarize_image, + summarize_pdf, + summarize_docx, + summarize_text, + ) + from app.billing.tier_manager import tier_manager # noqa: PLC0415 + from app.billing.quota import add_token_usage # noqa: PLC0415 + + session_id: str = frame.get("sessionId") or frame.get("session_id", "") + files: list[dict] = frame.get("files", []) + + session = _index_sessions.get(session_id) + if not session or session.get("cancelled"): + return + + async with async_session() as db: + tier = await tier_manager.get_tier(user_id, db) + raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens") + cap: int | None = None if raw_cap == -1 else raw_cap + + for file_info in files: + if session.get("cancelled"): + return + + rel_path: str = file_info.get("relPath", "") + kind: str = file_info.get("kind", "text") + content: str = file_info.get("content", "") + ext: str = file_info.get("ext", "") + mime: str = file_info.get("mime", "application/octet-stream") + name: str = rel_path.split("/")[-1] or rel_path + + try: + if kind == "image": + res = await summarize_image(image_b64=content, mime=mime) + elif kind == "pdf": + res = await summarize_pdf(pdf_b64=content, name=name) + elif kind == "docx": + res = await summarize_docx(docx_b64=content, name=name) + else: + res = await summarize_text(content=content, ext=ext, name=name) + except Exception as exc: + logger.warning( + "device_ws: index_file_batch summarize failed session=%s path=%s: %s", + session_id, rel_path, exc, + ) + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_file_result, + "sessionId": session_id, + "relPath": rel_path, + "summary": None, + "tokensUsed": 0, + "error": str(exc), + })) + session["processed"] += 1 + continue + + # Account for token usage and check cap. 
+ usage = await add_token_usage( + user_id=user_id, + feature="folder_index", + tokens=res.tokens_used, + db=db, + cap=cap, + ) + + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_file_result, + "sessionId": session_id, + "relPath": rel_path, + "summary": res.summary, + "tokensUsed": res.tokens_used, + })) + session["processed"] += 1 + + if usage.exhausted: + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_done, + "sessionId": session_id, + "status": "quota_exceeded", + })) + _index_sessions.pop(session_id, None) + logger.info( + "device_ws: index_session quota_exceeded user=%s session=%s", + user_id, session_id, + ) + return + + # After processing the batch, emit progress. + processed = session["processed"] + total = session["total"] + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_progress, + "sessionId": session_id, + "processed": processed, + "total": total, + })) + + if processed >= total: + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_done, + "sessionId": session_id, + "status": "completed", + })) + _index_sessions.pop(session_id, None) + logger.info( + "device_ws: index_session_done completed user=%s session=%s processed=%d", + user_id, session_id, processed, + ) + + # ── Heartbeat ───────────────────────────────────────────────────────── async def _heartbeat_loop(websocket: WebSocket) -> None: diff --git a/app/schemas.py b/app/schemas.py index 6bf1db5..ba4d283 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -89,6 +89,13 @@ class WsFrameType(str, Enum): brief_request = "brief_request" # ── v6 task brief frame types ───────────────────────────────────── task_brief_request = "task_brief_request" + # ── v7 folder index frame types ─────────────────────────────────── + index_session_start = "index_session_start" + index_file_batch = "index_file_batch" + index_session_cancel = "index_session_cancel" + index_file_result = "index_file_result" + 
index_session_progress = "index_session_progress" + index_session_done = "index_session_done" class WsToolCall(BaseModel): diff --git a/tests/test_ws_index_session.py b/tests/test_ws_index_session.py new file mode 100644 index 0000000..48eaeca --- /dev/null +++ b/tests/test_ws_index_session.py @@ -0,0 +1,196 @@ +"""Tests for WS folder index_session handlers (Task 9). + +Tests the three handler functions directly with a minimal fake WebSocket so +no real WS connection or LLM call is made. +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, patch + +import pytest +import pytest_asyncio + +from app.api.routes.device_ws import ( + _handle_index_session_start, + _handle_index_file_batch, + _handle_index_session_cancel, + _index_sessions, +) +from app.billing.quota import add_token_usage +from app.core.folder_indexer import IndexResult +from app.models import MonthlyTokenUsage +from app.schemas import WsFrameType +from tests.conftest import TEST_USER_IDS + +pytestmark = pytest.mark.asyncio + +USER_ID = TEST_USER_IDS["free"] +POWER_USER_ID = TEST_USER_IDS["power"] + + +# ── Fake WebSocket ──────────────────────────────────────────────────── + +class _FakeWebSocket: + """Minimal WebSocket stand-in that records send_text calls.""" + + def __init__(self) -> None: + self.sent: list[dict] = [] + + async def send_text(self, text: str) -> None: + self.sent.append(json.loads(text)) + + def sent_types(self) -> list[str]: + return [f["type"] for f in self.sent] + + +# ── Helpers ─────────────────────────────────────────────────────────── + +def _make_session_id() -> str: + import uuid + return str(uuid.uuid4()) + + +def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100): + """Return an async function that ignores its keyword arguments and returns a fixed IndexResult.""" + async def _fake(**kwargs) -> IndexResult: + return IndexResult(summary=summary, tokens_used=tokens) + return _fake + + +# ── 
Fixtures ────────────────────────────────────────────────────────── + +@pytest_asyncio.fixture(autouse=True) +async def _clean_sessions(): + """Ensure _index_sessions is empty before and after each test.""" + _index_sessions.clear() + yield + _index_sessions.clear() + + +# ── Tests ───────────────────────────────────────────────────────────── + +async def test_index_session_happy_path(db_session): + """start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed).""" + ws = _FakeWebSocket() + session_id = _make_session_id() + + # Register session. + await _handle_index_session_start(ws, USER_ID, { + "sessionId": session_id, + "projectId": "proj-1", + "totalFiles": 2, + }) + + # Verify session was registered. + assert session_id in _index_sessions + assert _index_sessions[session_id]["total"] == 2 + assert _index_sessions[session_id]["processed"] == 0 + # No response frames expected for session_start. + assert ws.sent == [] + + # Send batch of 2 text files — patch summarize_text so no LLM call needed. + with patch( + "app.api.routes.device_ws._handle_index_file_batch.__globals__", + # We patch the module-level function in folder_indexer instead: + ) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()): + with patch("app.api.routes.device_ws.async_session") as mock_async_session: + # Wire db_session into the context manager. + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=db_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_async_session.return_value = mock_cm + + await _handle_index_file_batch(ws, USER_ID, { + "sessionId": session_id, + "files": [ + {"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"}, + {"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"}, + ], + }) + + types = ws.sent_types() + # Expect 2 file results + 1 progress + 1 done(completed). 
+ assert types.count(WsFrameType.index_file_result) == 2 + assert types.count(WsFrameType.index_session_progress) == 1 + assert types.count(WsFrameType.index_session_done) == 1 + + done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done) + assert done_frame["status"] == "completed" + + progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress) + assert progress_frame["processed"] == 2 + assert progress_frame["total"] == 2 + + # Verify session cleaned up. + assert session_id not in _index_sessions + + +async def test_index_session_cancel(db_session): + """start then cancel → index_session_done(cancelled).""" + ws = _FakeWebSocket() + session_id = _make_session_id() + + await _handle_index_session_start(ws, USER_ID, { + "sessionId": session_id, + "totalFiles": 5, + }) + assert session_id in _index_sessions + + await _handle_index_session_cancel(ws, {"sessionId": session_id}) + + types = ws.sent_types() + assert WsFrameType.index_session_done in types + done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done) + assert done_frame["status"] == "cancelled" + + # Session should be cleaned up. + assert session_id not in _index_sessions + + +async def test_index_session_quota_exceeded(db_session): + """Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded).""" + ws = _FakeWebSocket() + session_id = _make_session_id() + + # Pre-fill monthly token usage to the free-tier cap (100_000). 
+ ym = datetime.now(timezone.utc).strftime("%Y-%m") + db_session.add(MonthlyTokenUsage( + user_id=USER_ID, + year_month=ym, + feature="folder_index", + tokens_used=100_000, # free tier cap exactly + )) + await db_session.commit() + + await _handle_index_session_start(ws, USER_ID, { + "sessionId": session_id, + "totalFiles": 1, + }) + + with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)): + with patch("app.api.routes.device_ws.async_session") as mock_async_session: + mock_cm = AsyncMock() + mock_cm.__aenter__ = AsyncMock(return_value=db_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + mock_async_session.return_value = mock_cm + + await _handle_index_file_batch(ws, USER_ID, { + "sessionId": session_id, + "files": [ + {"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"}, + ], + }) + + types = ws.sent_types() + # Should have 1 file result (success) then done(quota_exceeded). + assert WsFrameType.index_file_result in types + assert WsFrameType.index_session_done in types + + done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done) + assert done_frame["status"] == "quota_exceeded" + + # Session should be cleaned up. + assert session_id not in _index_sessions