scope episodic memory enrichment by session_id
This commit is contained in:
@@ -230,7 +230,12 @@ async def _handle_home_request(
|
|||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message, trace_id=request_id)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
@@ -294,7 +299,12 @@ async def _handle_floating_request(
|
|||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(user_id, message, trace_id=request_id)
|
memory_context = await memory.enrich_context(
|
||||||
|
user_id,
|
||||||
|
message,
|
||||||
|
trace_id=request_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"scope": scope,
|
"scope": scope,
|
||||||
|
|||||||
@@ -50,7 +50,13 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(self, user_id: str, message: str, trace_id: str | None = None) -> dict[str, Any]:
|
async def enrich_context(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
@@ -65,7 +71,7 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
@@ -380,10 +386,17 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_episodic(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
fernet: Fernet,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list[str]:
|
||||||
|
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||||
|
if session_id:
|
||||||
|
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryEpisodic)
|
query
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -110,6 +110,32 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
|
|||||||
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
|
||||||
|
target_session = str(uuid.uuid4())
|
||||||
|
other_session = str(uuid.uuid4())
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Target session memory"),
|
||||||
|
session_id=target_session,
|
||||||
|
))
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("Other session memory"),
|
||||||
|
session_id=other_session,
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
|
||||||
|
|
||||||
|
episodic = ctx.get("episodic_memory", [])
|
||||||
|
assert any("Target session" in s for s in episodic)
|
||||||
|
assert not any("Other session" in s for s in episodic)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
# Add one pattern above threshold and one below
|
# Add one pattern above threshold and one below
|
||||||
@@ -274,11 +300,11 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def enrich_context(self, user_id, message):
|
async def enrich_context(self, user_id, message, **kwargs):
|
||||||
enrich_calls.append((user_id, message))
|
enrich_calls.append((user_id, message))
|
||||||
return {"core_memory": {"tz": "UTC"}}
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
async def store_episode(self, user_id, session_id, message, response):
|
async def store_episode(self, user_id, session_id, message, response, **kwargs):
|
||||||
store_calls.append((user_id, session_id, message, response))
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|||||||
Reference in New Issue
Block a user