feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch, index_file_result/progress/done), wire dispatch in device_ws message loop, and implement _handle_index_session_start/cancel/file_batch with per-file summarisation, token accounting, and quota enforcement. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -57,6 +57,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
# ── v7 folder index session state ─────────────────────────────────────
|
||||
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||
_index_sessions: dict[str, dict] = {}
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
@@ -180,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_journey_message(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_start:
|
||||
asyncio.create_task(
|
||||
_handle_index_session_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_file_batch:
|
||||
asyncio.create_task(
|
||||
_handle_index_file_batch(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_cancel:
|
||||
await _handle_index_session_cancel(websocket, frame)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
@@ -569,6 +586,173 @@ async def _handle_journey_message(
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_index_session_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
|
||||
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||
total: int = int(frame.get("totalFiles", 0))
|
||||
|
||||
if not session_id:
|
||||
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||
return
|
||||
|
||||
_index_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"processed": 0,
|
||||
"total": total,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(
|
||||
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||
user_id, session_id, project_id, total,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_index_session_cancel(
|
||||
websocket: WebSocket,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
|
||||
session = _index_sessions.get(session_id)
|
||||
if session:
|
||||
session["cancelled"] = True
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "cancelled",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||
|
||||
|
||||
async def _handle_index_file_batch(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Process a batch of files for an index session, streaming results back."""
|
||||
# Lazy imports to avoid heavy load at module startup.
|
||||
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||
summarize_image,
|
||||
summarize_pdf,
|
||||
summarize_docx,
|
||||
summarize_text,
|
||||
)
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
|
||||
files: list[dict] = frame.get("files", [])
|
||||
|
||||
session = _index_sessions.get(session_id)
|
||||
if not session or session.get("cancelled"):
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||
|
||||
for file_info in files:
|
||||
if session.get("cancelled"):
|
||||
return
|
||||
|
||||
rel_path: str = file_info.get("relPath", "")
|
||||
kind: str = file_info.get("kind", "text")
|
||||
content: str = file_info.get("content", "")
|
||||
ext: str = file_info.get("ext", "")
|
||||
mime: str = file_info.get("mime", "application/octet-stream")
|
||||
name: str = rel_path.split("/")[-1] or rel_path
|
||||
|
||||
try:
|
||||
if kind == "image":
|
||||
res = await summarize_image(image_b64=content, mime=mime)
|
||||
elif kind == "pdf":
|
||||
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||
elif kind == "docx":
|
||||
res = await summarize_docx(docx_b64=content, name=name)
|
||||
else:
|
||||
res = await summarize_text(content=content, ext=ext, name=name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||
session_id, rel_path, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": None,
|
||||
"tokensUsed": 0,
|
||||
"error": str(exc),
|
||||
}))
|
||||
session["processed"] += 1
|
||||
continue
|
||||
|
||||
# Account for token usage and check cap.
|
||||
usage = await add_token_usage(
|
||||
user_id=user_id,
|
||||
feature="folder_index",
|
||||
tokens=res.tokens_used,
|
||||
db=db,
|
||||
cap=cap,
|
||||
)
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": res.summary,
|
||||
"tokensUsed": res.tokens_used,
|
||||
}))
|
||||
session["processed"] += 1
|
||||
|
||||
if usage.exhausted:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "quota_exceeded",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||
user_id, session_id,
|
||||
)
|
||||
return
|
||||
|
||||
# After processing the batch, emit progress.
|
||||
processed = session["processed"]
|
||||
total = session["total"]
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_progress,
|
||||
"sessionId": session_id,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
}))
|
||||
|
||||
if processed >= total:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "completed",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||
user_id, session_id, processed,
|
||||
)
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
|
||||
@@ -89,6 +89,13 @@ class WsFrameType(str, Enum):
|
||||
brief_request = "brief_request"
|
||||
# ── v6 task brief frame types ─────────────────────────────────────
|
||||
task_brief_request = "task_brief_request"
|
||||
# ── v7 folder index frame types ───────────────────────────────────
|
||||
index_session_start = "index_session_start"
|
||||
index_file_batch = "index_file_batch"
|
||||
index_session_cancel = "index_session_cancel"
|
||||
index_file_result = "index_file_result"
|
||||
index_session_progress = "index_session_progress"
|
||||
index_session_done = "index_session_done"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
|
||||
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Tests for WS folder index_session handlers (Task 9).
|
||||
|
||||
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||
no real WS connection or LLM call is made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.api.routes.device_ws import (
|
||||
_handle_index_session_start,
|
||||
_handle_index_file_batch,
|
||||
_handle_index_session_cancel,
|
||||
_index_sessions,
|
||||
)
|
||||
from app.billing.quota import add_token_usage
|
||||
from app.core.folder_indexer import IndexResult
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import WsFrameType
|
||||
from tests.conftest import TEST_USER_IDS
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
USER_ID = TEST_USER_IDS["free"]
|
||||
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||
|
||||
|
||||
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||
|
||||
class _FakeWebSocket:
|
||||
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent: list[dict] = []
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
self.sent.append(json.loads(text))
|
||||
|
||||
def sent_types(self) -> list[str]:
|
||||
return [f["type"] for f in self.sent]
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _make_session_id() -> str:
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||
async def _fake(**kwargs) -> IndexResult:
|
||||
return IndexResult(summary=summary, tokens_used=tokens)
|
||||
return _fake
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def _clean_sessions():
|
||||
"""Ensure _index_sessions is empty before and after each test."""
|
||||
_index_sessions.clear()
|
||||
yield
|
||||
_index_sessions.clear()
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def test_index_session_happy_path(db_session):
|
||||
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Register session.
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"projectId": "proj-1",
|
||||
"totalFiles": 2,
|
||||
})
|
||||
|
||||
# Verify session was registered.
|
||||
assert session_id in _index_sessions
|
||||
assert _index_sessions[session_id]["total"] == 2
|
||||
assert _index_sessions[session_id]["processed"] == 0
|
||||
# No response frames expected for session_start.
|
||||
assert ws.sent == []
|
||||
|
||||
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||
with patch(
|
||||
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||
# We patch the module-level function in folder_indexer instead:
|
||||
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
# Wire db_session into the context manager.
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||
assert types.count(WsFrameType.index_file_result) == 2
|
||||
assert types.count(WsFrameType.index_session_progress) == 1
|
||||
assert types.count(WsFrameType.index_session_done) == 1
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "completed"
|
||||
|
||||
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||
assert progress_frame["processed"] == 2
|
||||
assert progress_frame["total"] == 2
|
||||
|
||||
# Verify session cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_cancel(db_session):
|
||||
"""start then cancel → index_session_done(cancelled)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 5,
|
||||
})
|
||||
assert session_id in _index_sessions
|
||||
|
||||
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||
|
||||
types = ws.sent_types()
|
||||
assert WsFrameType.index_session_done in types
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "cancelled"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
|
||||
|
||||
async def test_index_session_quota_exceeded(db_session):
|
||||
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||
ws = _FakeWebSocket()
|
||||
session_id = _make_session_id()
|
||||
|
||||
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
db_session.add(MonthlyTokenUsage(
|
||||
user_id=USER_ID,
|
||||
year_month=ym,
|
||||
feature="folder_index",
|
||||
tokens_used=100_000, # free tier cap exactly
|
||||
))
|
||||
await db_session.commit()
|
||||
|
||||
await _handle_index_session_start(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"totalFiles": 1,
|
||||
})
|
||||
|
||||
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||
mock_cm = AsyncMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_async_session.return_value = mock_cm
|
||||
|
||||
await _handle_index_file_batch(ws, USER_ID, {
|
||||
"sessionId": session_id,
|
||||
"files": [
|
||||
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||
],
|
||||
})
|
||||
|
||||
types = ws.sent_types()
|
||||
# Should have 1 file result (success) then done(quota_exceeded).
|
||||
assert WsFrameType.index_file_result in types
|
||||
assert WsFrameType.index_session_done in types
|
||||
|
||||
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||
assert done_frame["status"] == "quota_exceeded"
|
||||
|
||||
# Session should be cleaned up.
|
||||
assert session_id not in _index_sessions
|
||||
Reference in New Issue
Block a user