diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index 0c70cd4..86cc728 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -230,7 +230,12 @@ async def _handle_home_request( # ── Memory: enrich context before LLM call ──────────────────────── async with async_session() as db: memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context(user_id, message, trace_id=request_id) + memory_context = await memory.enrich_context( + user_id, + message, + trace_id=request_id, + session_id=session_id, + ) context: dict = { "conversation_history": frame.get("conversation_history", []), @@ -294,7 +299,12 @@ async def _handle_floating_request( # ── Memory: enrich context before LLM call ──────────────────────── async with async_session() as db: memory = MemoryMiddleware(db) - memory_context = await memory.enrich_context(user_id, message, trace_id=request_id) + memory_context = await memory.enrich_context( + user_id, + message, + trace_id=request_id, + session_id=session_id, + ) context: dict = { "scope": scope, diff --git a/app/core/memory_middleware.py b/app/core/memory_middleware.py index 0a55199..e1b2f64 100644 --- a/app/core/memory_middleware.py +++ b/app/core/memory_middleware.py @@ -50,7 +50,13 @@ class MemoryMiddleware: # ── Public API ──────────────────────────────────────────────────────────── - async def enrich_context(self, user_id: str, message: str, trace_id: str | None = None) -> dict[str, Any]: + async def enrich_context( + self, + user_id: str, + message: str, + trace_id: str | None = None, + session_id: str | None = None, + ) -> dict[str, Any]: """Build memory context dict to inject into the orchestrator before LLM call. Returns a dict with keys: @@ -65,7 +71,7 @@ class MemoryMiddleware: core = await self._load_core(user_id, fernet) associative = await self._load_associative(user_id, message, fernet) - episodic = await self._load_episodic(user_id, fernet) + episodic = await self._load_episodic(user_id, fernet, session_id=session_id) proactive = await self._load_proactive(user_id, fernet) user_dbg = await self._get_user_debug(user_id) @@ -380,10 +386,17 @@ class MemoryMiddleware: out.append(plaintext) return out - async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]: + async def _load_episodic( + self, + user_id: str, + fernet: Fernet, + session_id: str | None = None, + ) -> list[str]: + query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id) + if session_id: + query = query.where(MemoryEpisodic.session_id == session_id) result = await self._db.execute( - select(MemoryEpisodic) - .where(MemoryEpisodic.user_id == user_id) + query .order_by(MemoryEpisodic.created_at.desc()) .limit(_EPISODIC_RECENT_N) ) diff --git a/tests/test_memory_middleware.py b/tests/test_memory_middleware.py index c978c1a..1ba6f7f 100644 --- a/tests/test_memory_middleware.py +++ b/tests/test_memory_middleware.py @@ -110,6 +110,32 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key) assert any("Q1 tasks" in s for s in ctx["episodic_memory"]) +@pytest.mark.asyncio +async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key): + target_session = str(uuid.uuid4()) + other_session = str(uuid.uuid4()) + db_session.add(MemoryEpisodic( + id=str(uuid.uuid4()), + user_id=USER_ID, + summary_encrypted=_enc("Target session memory"), + session_id=target_session, + )) + db_session.add(MemoryEpisodic( + id=str(uuid.uuid4()), + user_id=USER_ID, + summary_encrypted=_enc("Other session memory"), + session_id=other_session, + )) + await db_session.commit() + + middleware = MemoryMiddleware(db_session) + ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session) + + episodic = ctx.get("episodic_memory", []) + assert any("Target session" in s for s in episodic) + assert not any("Other session" in s for s in episodic) + + @pytest.mark.asyncio async def test_enrich_context_returns_proactive_hints(db_session, user_with_key): # Add one pattern above threshold and one below @@ -274,11 +300,11 @@ def test_home_request_calls_memory_middleware(client): def __init__(self, db): pass - async def enrich_context(self, user_id, message): + async def enrich_context(self, user_id, message, **kwargs): enrich_calls.append((user_id, message)) return {"core_memory": {"tz": "UTC"}} - async def store_episode(self, user_id, session_id, message, response): + async def store_episode(self, user_id, session_id, message, response, **kwargs): store_calls.append((user_id, session_id, message, response)) token = make_jwt("power", user_id=USER_ID)