WsDeviceHello.agent_ids → scout_ids in Pydantic schema, device_ws.py handler, and all test fixtures (test_device_ws, test_ws_unified, test_memory_middleware). Also fixes stale CloudAgentConfig reference in gmail.py docstring. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
375 lines
13 KiB
Python
375 lines
13 KiB
Python
"""Tests for Step 7 — MemoryMiddleware.
|
|
|
|
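Coverage:
  1. enrich_context returns core prefs + associative + episodic + proactive
  2. store_episode creates an encrypted row decryptable with the user's key
  3. update_core upserts correctly
  4. User with no encryption_key returns empty context (no crash)
  5. End-to-end: a home_request WS frame calls enrich_context before the
     stream and store_episode after it (middleware is mocked)
  6. Core block edit helpers: append / replace / list / delete / get
  7. Archival insert plus archival and recall search helpers
  8. embed_text returns a 1536-float vector, or None when the client fails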
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from cryptography.fernet import Fernet
|
|
from sqlalchemy import select
|
|
|
|
from app.core.embeddings import embed_text
|
|
from app.core.memory_middleware import MemoryMiddleware
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import (
|
|
MemoryAssociative,
|
|
MemoryCore,
|
|
MemoryEpisodic,
|
|
MemoryProactive,
|
|
User,
|
|
)
|
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
|
|
|
|
|
# The "power" entry of the shared test user ids; every test in this module
# acts as this user.
USER_ID = TEST_USER_IDS["power"]

# Fernet key generated once at import time and kept as text, so it can be
# assigned to ``User.encryption_key`` and re-encoded by the helpers below.
_FERNET_KEY = Fernet.generate_key().decode()
|
|
|
|
|
|
# ── DB override ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's ``get_session`` dependency at the test session.

    Applied automatically to every test in this module. The override is
    registered before the test body runs and removed afterwards.
    """

    async def _yield_test_session():
        # Async generator that hands out the per-test session.
        yield db_session

    overrides = app.dependency_overrides
    overrides[get_session] = _yield_test_session
    yield
    # ``pop`` with a default so teardown does not fail if the key is gone.
    overrides.pop(get_session, None)
|
|
|
|
|
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
|
|
|
@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Assign the module's Fernet key to the power user and return that user."""
    lookup = select(User).where(User.id == USER_ID)
    power_user = (await db_session.execute(lookup)).scalar_one()
    power_user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return power_user
|
|
|
|
|
|
def _fernet():
    """Build a Fernet cipher from the module-level test key."""
    key_bytes = _FERNET_KEY.encode()
    return Fernet(key_bytes)
|
|
|
|
|
|
def _enc(plaintext: str) -> str:
    """Encrypt *plaintext* with the test key and return the token as text."""
    token = _fernet().encrypt(plaintext.encode())
    return token.decode()
|
|
|
|
|
|
def _dec(ciphertext: str) -> str:
    """Decrypt a textual Fernet token with the test key and return the text."""
    token = ciphertext.encode()
    return _fernet().decrypt(token).decode()
|
|
|
|
|
|
# ── enrich_context ────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
    """A seeded core row is returned decrypted under ``core_memory``."""
    # Seed one encrypted core-memory entry for the user.
    seeded_row = MemoryCore(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        key="timezone",
        value_encrypted=_enc("UTC"),
    )
    db_session.add(seeded_row)
    await db_session.commit()

    memory = MemoryMiddleware(db_session)
    context = await memory.enrich_context(USER_ID, "What are my tasks?")

    assert "core_memory" in context
    assert context["core_memory"]["timezone"] == "UTC"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
    """A seeded episodic summary shows up decrypted under ``episodic_memory``."""
    session_id = str(uuid.uuid4())
    episode = MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_enc("User asked about Q1 tasks"),
        session_id=session_id,
    )
    db_session.add(episode)
    await db_session.commit()

    memory = MemoryMiddleware(db_session)
    context = await memory.enrich_context(USER_ID, "any message")

    assert "episodic_memory" in context
    summaries = context["episodic_memory"]
    assert any("Q1 tasks" in summary for summary in summaries)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
    """Passing ``session_id`` limits episodic memory to that session's rows."""
    target_session = str(uuid.uuid4())
    other_session = str(uuid.uuid4())

    # One episode per session; only the target one should come back.
    seeds = (
        ("Target session memory", target_session),
        ("Other session memory", other_session),
    )
    for summary, owning_session in seeds:
        db_session.add(MemoryEpisodic(
            id=str(uuid.uuid4()),
            user_id=USER_ID,
            summary_encrypted=_enc(summary),
            session_id=owning_session,
        ))
    await db_session.commit()

    memory = MemoryMiddleware(db_session)
    context = await memory.enrich_context(
        USER_ID, "any message", session_id=target_session
    )

    episodic = context.get("episodic_memory", [])
    assert any("Target session" in entry for entry in episodic)
    assert all("Other session" not in entry for entry in episodic)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    """Only the high-confidence proactive pattern is surfaced as a hint."""
    # One pattern with high confidence (0.9) and one with low confidence (0.1).
    patterns = (
        ("User prefers short summaries", 0.9),
        ("User likes dark mode", 0.1),
    )
    for pattern_text, score in patterns:
        db_session.add(MemoryProactive(
            id=str(uuid.uuid4()),
            user_id=USER_ID,
            pattern_encrypted=_enc(pattern_text),
            confidence=score,
            source="inferred",
        ))
    await db_session.commit()

    memory = MemoryMiddleware(db_session)
    context = await memory.enrich_context(USER_ID, "any message")

    assert "proactive_hints" in context
    hints = context["proactive_hints"]
    assert any("short summaries" in hint for hint in hints)
    assert all("dark mode" not in hint for hint in hints)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
    """An associative row stored without an embedding is still returned."""
    note = MemoryAssociative(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        content_encrypted=_enc("Related memory about meetings"),
        embedding=None,
        entity_type="note",
    )
    db_session.add(note)
    await db_session.commit()

    memory = MemoryMiddleware(db_session)
    context = await memory.enrich_context(USER_ID, "meetings")

    assert "associative_memory" in context
    related = context["associative_memory"]
    assert any("meetings" in item for item in related)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_enrich_context_empty_for_user_without_key(db_session):
    """With ``encryption_key`` cleared, enrich_context yields ``{}`` and does not raise."""
    lookup = select(User).where(User.id == USER_ID)
    power_user = (await db_session.execute(lookup)).scalar_one()
    # Explicitly clear the key so the no-key path is exercised.
    power_user.encryption_key = None
    await db_session.commit()

    context = await MemoryMiddleware(db_session).enrich_context(USER_ID, "hello")
    assert context == {}
|
|
|
|
|
|
# ── store_episode ─────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
    """store_episode writes one row whose decrypted summary has both sides of the exchange."""
    session_id = str(uuid.uuid4())
    memory = MemoryMiddleware(db_session)
    await memory.store_episode(USER_ID, session_id, "hello", "world")

    query = select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
    stored = (await db_session.execute(query)).scalar_one()

    summary = _dec(stored.summary_encrypted)
    assert "hello" in summary
    assert "world" in summary
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_store_episode_decryptable(db_session, user_with_key):
    """The stored summary decrypts with the user's key to a non-empty string."""
    session_id = str(uuid.uuid4())
    memory = MemoryMiddleware(db_session)
    await memory.store_episode(USER_ID, session_id, "msg", "resp")

    query = select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
    stored = (await db_session.execute(query)).scalar_one()

    # Decrypting with the same key must succeed without raising.
    summary = _dec(stored.summary_encrypted)
    assert len(summary) > 0
|
|
|
|
|
|
# ── update_core ───────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_update_core_insert(db_session, user_with_key):
    """update_core on a new key creates a row holding the encrypted value."""
    memory = MemoryMiddleware(db_session)
    await memory.update_core(USER_ID, "lang", "en")

    query = select(MemoryCore).where(
        MemoryCore.user_id == USER_ID, MemoryCore.key == "lang"
    )
    stored = (await db_session.execute(query)).scalar_one()
    assert _dec(stored.value_encrypted) == "en"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_core_upsert(db_session, user_with_key):
    """A second update_core for the same key overwrites instead of duplicating."""
    memory = MemoryMiddleware(db_session)
    for value in ("en", "fr"):
        await memory.update_core(USER_ID, "lang", value)

    query = select(MemoryCore).where(
        MemoryCore.user_id == USER_ID, MemoryCore.key == "lang"
    )
    matches = (await db_session.execute(query)).scalars().all()

    assert len(matches) == 1
    assert _dec(matches[0].value_encrypted) == "fr"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_core_block_edit_ops(db_session, user_with_key):
    """Exercise append, replace, list, delete and get on the ``human`` core block."""
    memory = MemoryMiddleware(db_session)

    # Build up the block, then rewrite part of it.
    await memory.update_core(USER_ID, "human", "Name: Roberto")
    await memory.append_core(USER_ID, "human", "Timezone: Europe/Rome")
    was_replaced = await memory.replace_core(USER_ID, "human", "Roberto", "Robert")

    all_blocks = await memory.list_core_blocks(USER_ID)
    human_block = next(block for block in all_blocks if block["label"] == "human")

    assert was_replaced is True
    assert "Name: Robert" in human_block["value"]
    assert "Timezone: Europe/Rome" in human_block["value"]

    # Once deleted, the block can no longer be fetched.
    was_deleted = await memory.delete_core(USER_ID, "human")
    assert was_deleted is True
    assert await memory.get_core_block(USER_ID, "human") is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
    """search_archival and search_recall each find the text stored just before."""
    memory = MemoryMiddleware(db_session)

    await memory.insert_archival(
        USER_ID, "Project whitelist has release risk", source="assistant"
    )
    await memory.store_episode(
        USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed"
    )

    archival_hits = await memory.search_archival(USER_ID, "whitelist", top_k=3)
    recall_hits = await memory.search_recall(USER_ID, "delayed", top_k=3)

    # Matching is done case-insensitively on the returned strings.
    assert any("whitelist" in hit.lower() for hit in archival_hits)
    assert any("delayed" in hit.lower() for hit in recall_hits)
|
|
|
|
|
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
|
|
|
def test_home_request_calls_memory_middleware(client):
    """home_request triggers enrich_context before and store_episode after the LLM.

    ``MemoryMiddleware`` and ``run_home_stream`` are both patched inside
    ``app.api.routes.device_ws``, so this test checks only the wiring:
    the context returned by ``enrich_context`` reaches the stream function,
    and ``store_episode`` is called once with the request's session id and
    message.
    """
    enrich_calls: list[tuple] = []
    store_calls: list[tuple] = []

    class _MockMiddleware:
        """Stand-in that records calls instead of touching the database."""

        def __init__(self, db):
            pass

        async def enrich_context(self, user_id, message, **kwargs):
            enrich_calls.append((user_id, message))
            return {"core_memory": {"tz": "UTC"}}

        async def store_episode(self, user_id, session_id, message, response, **kwargs):
            store_calls.append((user_id, session_id, message, response))

    token = make_jwt("power", user_id=USER_ID)
    session_id = str(uuid.uuid4())

    async def _mock_stream(user_id, message, context):
        # Verify memory context was injected
        assert context.get("core_memory") == {"tz": "UTC"}
        yield "token", "Done"

    with (
        patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
        patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
    ):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello", "device_id": "dev-mem", "scout_ids": []
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": "r-mem",
                "session_id": session_id,
                "message": "Show tasks",
            }))
            # Read frames until the stream finishes. The bound keeps a broken
            # handler from looping forever; exhausting it is reported as a
            # failure instead of silently falling through to the assertions.
            for _ in range(20):
                frame = json.loads(ws.receive_text())
                if frame.get("type") == "stream_end":
                    break
            else:
                pytest.fail("no stream_end frame received within 20 frames")

    assert len(enrich_calls) == 1
    assert enrich_calls[0] == (USER_ID, "Show tasks")
    assert len(store_calls) == 1
    stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
    assert stored_session_id == session_id
    assert stored_message == "Show tasks"
|
|
|
|
|
|
# ── embed_text ─────────────────────────────────────────────────────────────────
|
|
|
|
@pytest.mark.asyncio
async def test_embed_text_returns_1536_floats():
    """embed_text returns a 1536-dim float list when OpenAI responds successfully."""
    expected_vector = [0.1] * 1536

    # Fake response exposing ``.data[0].embedding``.
    api_response = MagicMock()
    api_response.data = [MagicMock(embedding=expected_vector)]

    openai_client = MagicMock()
    openai_client.embeddings.create = AsyncMock(return_value=api_response)

    with patch("app.core.embeddings.AsyncOpenAI", return_value=openai_client):
        embedding = await embed_text("test text")

    assert embedding is not None
    assert len(embedding) == 1536
    assert all(isinstance(component, float) for component in embedding)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_embed_text_returns_none_on_failure():
    """embed_text returns None when OpenAI raises; must not propagate the exception."""
    # Constructing the client itself raises.
    failing_client = patch(
        "app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")
    )
    with failing_client:
        embedding = await embed_text("test text")

    assert embedding is None
|