scope episodic memory enrichment by session_id

This commit is contained in:
2026-03-16 00:33:11 +01:00
parent fae9efee0d
commit 02a9684cd6
3 changed files with 58 additions and 9 deletions

View File

@@ -230,7 +230,12 @@ async def _handle_home_request(
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(user_id, message, trace_id=request_id)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
@@ -294,7 +299,12 @@ async def _handle_floating_request(
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(user_id, message, trace_id=request_id)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"scope": scope,

View File

@@ -50,7 +50,13 @@ class MemoryMiddleware:
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(self, user_id: str, message: str, trace_id: str | None = None) -> dict[str, Any]:
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
@@ -65,7 +71,7 @@ class MemoryMiddleware:
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
user_dbg = await self._get_user_debug(user_id)
@@ -380,10 +386,17 @@ class MemoryMiddleware:
out.append(plaintext)
return out
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
async def _load_episodic(
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
query
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
)