"""Tests for Step 7 — MemoryMiddleware.

Coverage:
  1. enrich_context returns core prefs + associative + episodic + proactive
     (episodic can be filtered by session_id; proactive hints below the
     confidence cutoff are excluded)
  2. store_episode creates an encrypted row decryptable with the user's key
  3. update_core upserts correctly; append/replace/delete core-block edit ops
  4. User with no encryption_key returns empty context (no crash)
  5. Archival insert/search and recall search helpers
  6. End-to-end: a home_request WS frame calls enrich_context before the LLM
     stream and store_episode after it (middleware is mocked, no DB row)
  7. embed_text returns a 1536-float vector, or None when OpenAI fails
"""
from __future__ import annotations

import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select

from app.core.embeddings import embed_text
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import (
    MemoryAssociative,
    MemoryCore,
    MemoryEpisodic,
    MemoryProactive,
    User,
)
from tests.conftest import TEST_USER_IDS, make_jwt

USER_ID = TEST_USER_IDS["power"]
# One Fernet key for the whole module; assigned to the power user by the
# ``user_with_key`` fixture and used by the _enc/_dec helpers below.
_FERNET_KEY = Fernet.generate_key().decode()


# ── DB override ───────────────────────────────────────────────────────────────

@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Route the app's ``get_session`` dependency to the test's session."""
    async def _gen():
        yield db_session

    app.dependency_overrides[get_session] = _gen
    yield
    app.dependency_overrides.pop(get_session, None)


# ── Fixtures ──────────────────────────────────────────────────────────────────

@pytest_asyncio.fixture
async def user_with_key(db_session):
    """Set encryption_key on the seeded power user."""
    result = await db_session.execute(select(User).where(User.id == USER_ID))
    user = result.scalar_one()
    user.encryption_key = _FERNET_KEY
    await db_session.commit()
    return user


def _fernet():
    """Return a Fernet cipher for the module-level test key."""
    return Fernet(_FERNET_KEY.encode())


def _enc(plaintext: str) -> str:
    """Encrypt *plaintext* with the test key; returns a str token."""
    return _fernet().encrypt(plaintext.encode()).decode()


def _dec(ciphertext: str) -> str:
    """Decrypt a str token produced with the test key."""
    return _fernet().decrypt(ciphertext.encode()).decode()


# ── enrich_context ────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
    # Seed a core memory row
    db_session.add(MemoryCore(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        key="timezone",
        value_encrypted=_enc("UTC"),
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "What are my tasks?")

    assert "core_memory" in ctx
    assert ctx["core_memory"]["timezone"] == "UTC"


@pytest.mark.asyncio
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
    session_id = str(uuid.uuid4())
    db_session.add(MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_enc("User asked about Q1 tasks"),
        session_id=session_id,
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "any message")

    assert "episodic_memory" in ctx
    assert any("Q1 tasks" in s for s in ctx["episodic_memory"])


@pytest.mark.asyncio
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
    target_session = str(uuid.uuid4())
    other_session = str(uuid.uuid4())
    db_session.add(MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_enc("Target session memory"),
        session_id=target_session,
    ))
    db_session.add(MemoryEpisodic(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        summary_encrypted=_enc("Other session memory"),
        session_id=other_session,
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)

    episodic = ctx.get("episodic_memory", [])
    assert any("Target session" in s for s in episodic)
    assert not any("Other session" in s for s in episodic)


@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
    # Add one pattern above threshold and one below (the middleware's
    # confidence cutoff is assumed to lie between 0.1 and 0.9).
    db_session.add(MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc("User prefers short summaries"),
        confidence=0.9,
        source="inferred",
    ))
    db_session.add(MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        pattern_encrypted=_enc("User likes dark mode"),
        confidence=0.1,
        source="inferred",
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "any message")

    assert "proactive_hints" in ctx
    hints = ctx["proactive_hints"]
    assert any("short summaries" in h for h in hints)
    assert not any("dark mode" in h for h in hints)


@pytest.mark.asyncio
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
    db_session.add(MemoryAssociative(
        id=str(uuid.uuid4()),
        user_id=USER_ID,
        content_encrypted=_enc("Related memory about meetings"),
        embedding=None,
        entity_type="note",
    ))
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "meetings")

    assert "associative_memory" in ctx
    assert any("meetings" in m for m in ctx["associative_memory"])


@pytest.mark.asyncio
async def test_enrich_context_empty_for_user_without_key(db_session):
    """User with no encryption_key → empty context, no crash."""
    result = await db_session.execute(select(User).where(User.id == USER_ID))
    user = result.scalar_one()
    user.encryption_key = None
    await db_session.commit()

    middleware = MemoryMiddleware(db_session)
    ctx = await middleware.enrich_context(USER_ID, "hello")

    assert ctx == {}


# ── store_episode ─────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
    session_id = str(uuid.uuid4())
    middleware = MemoryMiddleware(db_session)
    await middleware.store_episode(USER_ID, session_id, "hello", "world")

    result = await db_session.execute(
        select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
    )
    row = result.scalar_one()
    plaintext = _dec(row.summary_encrypted)
    assert "hello" in plaintext
    assert "world" in plaintext


@pytest.mark.asyncio
async def test_store_episode_decryptable(db_session, user_with_key):
    session_id = str(uuid.uuid4())
    middleware = MemoryMiddleware(db_session)
    await middleware.store_episode(USER_ID, session_id, "msg", "resp")

    result = await db_session.execute(
        select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
    )
    row = result.scalar_one()
    # Decrypt using the same key — must not raise
    decrypted = _dec(row.summary_encrypted)
    assert len(decrypted) > 0


# ── update_core ───────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_update_core_insert(db_session, user_with_key):
    middleware = MemoryMiddleware(db_session)
    await middleware.update_core(USER_ID, "lang", "en")

    result = await db_session.execute(
        select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
    )
    row = result.scalar_one()
    assert _dec(row.value_encrypted) == "en"


@pytest.mark.asyncio
async def test_update_core_upsert(db_session, user_with_key):
    middleware = MemoryMiddleware(db_session)
    await middleware.update_core(USER_ID, "lang", "en")
    await middleware.update_core(USER_ID, "lang", "fr")

    result = await db_session.execute(
        select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
    )
    rows = result.scalars().all()
    assert len(rows) == 1
    assert _dec(rows[0].value_encrypted) == "fr"


@pytest.mark.asyncio
async def test_core_block_edit_ops(db_session, user_with_key):
    middleware = MemoryMiddleware(db_session)
    await middleware.update_core(USER_ID, "human", "Name: Roberto")
    await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
    replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
    blocks = await middleware.list_core_blocks(USER_ID)
    human = next(b for b in blocks if b["label"] == "human")

    assert replaced is True
    assert "Name: Robert" in human["value"]
    # "Name: Robert" is also a substring of the original "Name: Roberto", so
    # the check above alone would pass even if replace_core were a no-op.
    assert "Roberto" not in human["value"]
    assert "Timezone: Europe/Rome" in human["value"]

    deleted = await middleware.delete_core(USER_ID, "human")
    assert deleted is True
    assert await middleware.get_core_block(USER_ID, "human") is None


@pytest.mark.asyncio
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
    middleware = MemoryMiddleware(db_session)
    await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
    await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")

    arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
    rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)

    assert any("whitelist" in item.lower() for item in arch)
    assert any("delayed" in item.lower() for item in rec)


# ── End-to-end WS: memory middleware is called during home_request ────────────

def test_home_request_calls_memory_middleware(client):
    """home_request triggers enrich_context before and store_episode after the LLM."""
    enrich_calls: list[tuple] = []
    store_calls: list[tuple] = []
    # Contexts handed to the (mocked) LLM stream. Recorded rather than asserted
    # inside the mock, because an AssertionError raised there runs inside the
    # server's WS handler, which may catch it.
    seen_contexts: list[dict] = []

    class _MockMiddleware:
        def __init__(self, db):
            pass

        async def enrich_context(self, user_id, message, **kwargs):
            enrich_calls.append((user_id, message))
            return {"core_memory": {"tz": "UTC"}}

        async def store_episode(self, user_id, session_id, message, response, **kwargs):
            store_calls.append((user_id, session_id, message, response))

    token = make_jwt("power", user_id=USER_ID)
    session_id = str(uuid.uuid4())

    async def _mock_stream(user_id, message, context):
        seen_contexts.append(context)
        yield "token", "Done"

    stream_ended = False
    with (
        patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
        patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
    ):
        with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
            ws.send_text(json.dumps({
                "type": "device_hello",
                "device_id": "dev-mem",
                "scout_ids": [],
            }))
            ws.send_text(json.dumps({
                "type": "home_request",
                "request_id": "r-mem",
                "session_id": session_id,
                "message": "Show tasks",
            }))
            # Bounded read: stop at stream_end, but remember whether we saw it
            # so a missing stream_end fails loudly instead of falling through.
            for _ in range(20):
                frame = json.loads(ws.receive_text())
                if frame.get("type") == "stream_end":
                    stream_ended = True
                    break

    assert stream_ended, "no stream_end frame received within 20 frames"

    # Verify memory context was injected into the LLM stream
    assert seen_contexts, "run_home_stream was never invoked"
    assert seen_contexts[0].get("core_memory") == {"tz": "UTC"}

    assert len(enrich_calls) == 1
    assert enrich_calls[0] == (USER_ID, "Show tasks")

    assert len(store_calls) == 1
    stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
    assert stored_session_id == session_id
    assert stored_message == "Show tasks"


# ── embed_text ─────────────────────────────────────────────────────────────────

@pytest.mark.asyncio
async def test_embed_text_returns_1536_floats():
    """embed_text returns a 1536-dim float list when OpenAI responds successfully."""
    fake_embedding = [0.1] * 1536
    mock_response = MagicMock()
    mock_response.data = [MagicMock(embedding=fake_embedding)]
    mock_client = MagicMock()
    mock_client.embeddings.create = AsyncMock(return_value=mock_response)

    with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
        result = await embed_text("test text")

    assert result is not None
    assert len(result) == 1536
    assert all(isinstance(x, float) for x in result)


@pytest.mark.asyncio
async def test_embed_text_returns_none_on_failure():
    """embed_text returns None when OpenAI raises; must not propagate the exception."""
    with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
        result = await embed_text("test text")

    assert result is None