feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch, index_file_result/progress/done), wire dispatch in device_ws message loop, and implement _handle_index_session_start/cancel/file_batch with per-file summarisation, token accounting, and quota enforcement. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -57,6 +57,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
|
# ── v7 folder index session state ─────────────────────────────────────
|
||||||
|
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||||
|
_index_sessions: dict[str, dict] = {}
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
@@ -180,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_journey_message(websocket, user_id, frame)
|
_handle_journey_message(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_session_start:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_index_session_start(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_file_batch:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_index_file_batch(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.index_session_cancel:
|
||||||
|
await _handle_index_session_cancel(websocket, frame)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -569,6 +586,173 @@ async def _handle_journey_message(
|
|||||||
clear_client_executor()
|
clear_client_executor()
|
||||||
|
|
||||||
|
|
||||||
|
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_session_start(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
|
||||||
|
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||||
|
total: int = int(frame.get("totalFiles", 0))
|
||||||
|
|
||||||
|
if not session_id:
|
||||||
|
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
_index_sessions[session_id] = {
|
||||||
|
"user_id": user_id,
|
||||||
|
"project_id": project_id,
|
||||||
|
"processed": 0,
|
||||||
|
"total": total,
|
||||||
|
"cancelled": False,
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||||
|
user_id, session_id, project_id, total,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_session_cancel(
|
||||||
|
websocket: WebSocket,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
|
||||||
|
session = _index_sessions.get(session_id)
|
||||||
|
if session:
|
||||||
|
session["cancelled"] = True
|
||||||
|
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "cancelled",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_index_file_batch(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Process a batch of files for an index session, streaming results back."""
|
||||||
|
# Lazy imports to avoid heavy load at module startup.
|
||||||
|
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||||
|
summarize_image,
|
||||||
|
summarize_pdf,
|
||||||
|
summarize_docx,
|
||||||
|
summarize_text,
|
||||||
|
)
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||||
|
|
||||||
|
session_id: str = frame.get("sessionId") or frame.get("session_id", "")
|
||||||
|
files: list[dict] = frame.get("files", [])
|
||||||
|
|
||||||
|
session = _index_sessions.get(session_id)
|
||||||
|
if not session or session.get("cancelled"):
|
||||||
|
return
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||||
|
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
if session.get("cancelled"):
|
||||||
|
return
|
||||||
|
|
||||||
|
rel_path: str = file_info.get("relPath", "")
|
||||||
|
kind: str = file_info.get("kind", "text")
|
||||||
|
content: str = file_info.get("content", "")
|
||||||
|
ext: str = file_info.get("ext", "")
|
||||||
|
mime: str = file_info.get("mime", "application/octet-stream")
|
||||||
|
name: str = rel_path.split("/")[-1] or rel_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
if kind == "image":
|
||||||
|
res = await summarize_image(image_b64=content, mime=mime)
|
||||||
|
elif kind == "pdf":
|
||||||
|
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||||
|
elif kind == "docx":
|
||||||
|
res = await summarize_docx(docx_b64=content, name=name)
|
||||||
|
else:
|
||||||
|
res = await summarize_text(content=content, ext=ext, name=name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||||
|
session_id, rel_path, exc,
|
||||||
|
)
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_file_result,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"relPath": rel_path,
|
||||||
|
"summary": None,
|
||||||
|
"tokensUsed": 0,
|
||||||
|
"error": str(exc),
|
||||||
|
}))
|
||||||
|
session["processed"] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Account for token usage and check cap.
|
||||||
|
usage = await add_token_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
feature="folder_index",
|
||||||
|
tokens=res.tokens_used,
|
||||||
|
db=db,
|
||||||
|
cap=cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_file_result,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"relPath": rel_path,
|
||||||
|
"summary": res.summary,
|
||||||
|
"tokensUsed": res.tokens_used,
|
||||||
|
}))
|
||||||
|
session["processed"] += 1
|
||||||
|
|
||||||
|
if usage.exhausted:
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "quota_exceeded",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||||
|
user_id, session_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# After processing the batch, emit progress.
|
||||||
|
processed = session["processed"]
|
||||||
|
total = session["total"]
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_progress,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"processed": processed,
|
||||||
|
"total": total,
|
||||||
|
}))
|
||||||
|
|
||||||
|
if processed >= total:
|
||||||
|
await websocket.send_text(json.dumps({
|
||||||
|
"type": WsFrameType.index_session_done,
|
||||||
|
"sessionId": session_id,
|
||||||
|
"status": "completed",
|
||||||
|
}))
|
||||||
|
_index_sessions.pop(session_id, None)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||||
|
user_id, session_id, processed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
|||||||
@@ -89,6 +89,13 @@ class WsFrameType(str, Enum):
|
|||||||
brief_request = "brief_request"
|
brief_request = "brief_request"
|
||||||
# ── v6 task brief frame types ─────────────────────────────────────
|
# ── v6 task brief frame types ─────────────────────────────────────
|
||||||
task_brief_request = "task_brief_request"
|
task_brief_request = "task_brief_request"
|
||||||
|
# ── v7 folder index frame types ───────────────────────────────────
|
||||||
|
index_session_start = "index_session_start"
|
||||||
|
index_file_batch = "index_file_batch"
|
||||||
|
index_session_cancel = "index_session_cancel"
|
||||||
|
index_file_result = "index_file_result"
|
||||||
|
index_session_progress = "index_session_progress"
|
||||||
|
index_session_done = "index_session_done"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
|
|||||||
196
tests/test_ws_index_session.py
Normal file
196
tests/test_ws_index_session.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
"""Tests for WS folder index_session handlers (Task 9).
|
||||||
|
|
||||||
|
Tests the three handler functions directly with a minimal fake WebSocket so
|
||||||
|
no real WS connection or LLM call is made.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.api.routes.device_ws import (
|
||||||
|
_handle_index_session_start,
|
||||||
|
_handle_index_file_batch,
|
||||||
|
_handle_index_session_cancel,
|
||||||
|
_index_sessions,
|
||||||
|
)
|
||||||
|
from app.billing.quota import add_token_usage
|
||||||
|
from app.core.folder_indexer import IndexResult
|
||||||
|
from app.models import MonthlyTokenUsage
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["free"]
|
||||||
|
POWER_USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake WebSocket ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _FakeWebSocket:
|
||||||
|
"""Minimal WebSocket stand-in that records send_text calls."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[dict] = []
|
||||||
|
|
||||||
|
async def send_text(self, text: str) -> None:
|
||||||
|
self.sent.append(json.loads(text))
|
||||||
|
|
||||||
|
def sent_types(self) -> list[str]:
|
||||||
|
return [f["type"] for f in self.sent]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_session_id() -> str:
|
||||||
|
import uuid
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_summarize_text_factory(summary: str = "A test summary.", tokens: int = 100):
|
||||||
|
"""Return an AsyncMock that resolves to a fixed IndexResult."""
|
||||||
|
async def _fake(**kwargs) -> IndexResult:
|
||||||
|
return IndexResult(summary=summary, tokens_used=tokens)
|
||||||
|
return _fake
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def _clean_sessions():
|
||||||
|
"""Ensure _index_sessions is empty before and after each test."""
|
||||||
|
_index_sessions.clear()
|
||||||
|
yield
|
||||||
|
_index_sessions.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def test_index_session_happy_path(db_session):
|
||||||
|
"""start + batch of 2 text files → 2 index_file_result + 1 progress + 1 done(completed)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
# Register session.
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"projectId": "proj-1",
|
||||||
|
"totalFiles": 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
# Verify session was registered.
|
||||||
|
assert session_id in _index_sessions
|
||||||
|
assert _index_sessions[session_id]["total"] == 2
|
||||||
|
assert _index_sessions[session_id]["processed"] == 0
|
||||||
|
# No response frames expected for session_start.
|
||||||
|
assert ws.sent == []
|
||||||
|
|
||||||
|
# Send batch of 2 text files — patch summarize_text so no LLM call needed.
|
||||||
|
with patch(
|
||||||
|
"app.api.routes.device_ws._handle_index_file_batch.__globals__",
|
||||||
|
# We patch the module-level function in folder_indexer instead:
|
||||||
|
) if False else patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory()):
|
||||||
|
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||||
|
# Wire db_session into the context manager.
|
||||||
|
mock_cm = AsyncMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_async_session.return_value = mock_cm
|
||||||
|
|
||||||
|
await _handle_index_file_batch(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"files": [
|
||||||
|
{"relPath": "README.md", "kind": "text", "content": "hello", "ext": ".md"},
|
||||||
|
{"relPath": "notes.txt", "kind": "text", "content": "world", "ext": ".txt"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
# Expect 2 file results + 1 progress + 1 done(completed).
|
||||||
|
assert types.count(WsFrameType.index_file_result) == 2
|
||||||
|
assert types.count(WsFrameType.index_session_progress) == 1
|
||||||
|
assert types.count(WsFrameType.index_session_done) == 1
|
||||||
|
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "completed"
|
||||||
|
|
||||||
|
progress_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_progress)
|
||||||
|
assert progress_frame["processed"] == 2
|
||||||
|
assert progress_frame["total"] == 2
|
||||||
|
|
||||||
|
# Verify session cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def test_index_session_cancel(db_session):
|
||||||
|
"""start then cancel → index_session_done(cancelled)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"totalFiles": 5,
|
||||||
|
})
|
||||||
|
assert session_id in _index_sessions
|
||||||
|
|
||||||
|
await _handle_index_session_cancel(ws, {"sessionId": session_id})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
assert WsFrameType.index_session_done in types
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "cancelled"
|
||||||
|
|
||||||
|
# Session should be cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
|
|
||||||
|
|
||||||
|
async def test_index_session_quota_exceeded(db_session):
|
||||||
|
"""Pre-fill usage to cap → batch one file → index_session_done(quota_exceeded)."""
|
||||||
|
ws = _FakeWebSocket()
|
||||||
|
session_id = _make_session_id()
|
||||||
|
|
||||||
|
# Pre-fill monthly token usage to the free-tier cap (100_000).
|
||||||
|
ym = datetime.now(timezone.utc).strftime("%Y-%m")
|
||||||
|
db_session.add(MonthlyTokenUsage(
|
||||||
|
user_id=USER_ID,
|
||||||
|
year_month=ym,
|
||||||
|
feature="folder_index",
|
||||||
|
tokens_used=100_000, # free tier cap exactly
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
await _handle_index_session_start(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"totalFiles": 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
with patch("app.core.folder_indexer.summarize_text", side_effect=_fake_summarize_text_factory(tokens=1)):
|
||||||
|
with patch("app.api.routes.device_ws.async_session") as mock_async_session:
|
||||||
|
mock_cm = AsyncMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=db_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_async_session.return_value = mock_cm
|
||||||
|
|
||||||
|
await _handle_index_file_batch(ws, USER_ID, {
|
||||||
|
"sessionId": session_id,
|
||||||
|
"files": [
|
||||||
|
{"relPath": "file.md", "kind": "text", "content": "content", "ext": ".md"},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
types = ws.sent_types()
|
||||||
|
# Should have 1 file result (success) then done(quota_exceeded).
|
||||||
|
assert WsFrameType.index_file_result in types
|
||||||
|
assert WsFrameType.index_session_done in types
|
||||||
|
|
||||||
|
done_frame = next(f for f in ws.sent if f["type"] == WsFrameType.index_session_done)
|
||||||
|
assert done_frame["status"] == "quota_exceeded"
|
||||||
|
|
||||||
|
# Session should be cleaned up.
|
||||||
|
assert session_id not in _index_sessions
|
||||||
Reference in New Issue
Block a user