429 lines
16 KiB
Python
429 lines
16 KiB
Python
"""Memory Middleware — enrich requests with memory context and store interactions.
|
|
|
|
Four-tier memory model (MemGPT-style):
|
|
core — persistent key/value user preferences, always injected
|
|
associative — semantic similarity search via pgvector (top-k)
|
|
episodic — recent session summaries (last N)
|
|
proactive — behavioral patterns above confidence threshold
|
|
|
|
All memory content is encrypted at rest using the per-user Fernet key
|
|
stored in User.encryption_key. Decryption happens in-memory only.
|
|
|
|
Usage:
|
|
memory = MemoryMiddleware(db_session)
|
|
context = await memory.enrich_context(user_id, message)
|
|
# ... run agent ...
|
|
await memory.store_episode(user_id, session_id, message, response)
|
|
"""
|
|
|
|
from __future__ import annotations

import logging
import uuid
from collections.abc import Iterable
from typing import Any

from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import (
    MemoryAssociative,
    MemoryCore,
    MemoryEpisodic,
    MemoryProactive,
    User,
)
|
|
|
|
# Module-level logger; every message emitted by this module is prefixed "memory:".
logger = logging.getLogger(__name__)

# Tuning constants
# Row limit applied when loading associative memories for context enrichment.
_ASSOCIATIVE_TOP_K: int = 5
# Row limit applied when loading the newest episodic summaries.
_EPISODIC_RECENT_N: int = 10
# Proactive patterns are loaded only when confidence >= this value.
_PROACTIVE_CONFIDENCE_THRESHOLD: float = 0.6
|
|
|
|
|
|
class MemoryMiddleware:
    """Enrich orchestrator context with memory and persist interactions after.

    All reads and writes go through the injected async session. Methods that
    need the user's Fernet key degrade to an empty result or a no-op (rather
    than raising) when ``_get_fernet`` returns None. Write methods commit on
    the session themselves and roll back, logging the error, if the commit
    fails.
    """

    # Maximum number of newest rows decrypted and scanned by the keyword searches.
    _SEARCH_SCAN_LIMIT = 100

    def __init__(self, db: AsyncSession) -> None:
        """Store the session used for every query, commit and rollback."""
        self._db = db

    # ── Public API ────────────────────────────────────────────────────────────

    async def enrich_context(
        self, user_id: str, message: str, trace_id: str | None = None
    ) -> dict[str, Any]:
        """Build memory context dict to inject into the orchestrator before LLM call.

        Args:
            user_id: Owner of the memories to load.
            message: Incoming user message; forwarded to the associative loader.
            trace_id: Optional correlation id, used only in the log line.

        Returns:
            ``{}`` when no usable encryption key exists for the user, otherwise
            a dict with keys:
              core_memory        — {key: plaintext_value, ...}
              associative_memory — [plaintext_content, ...]  (most recently updated top-k)
              episodic_memory    — [plaintext_summary, ...]  (most recent N)
              proactive_hints    — [plaintext_pattern, ...]  (above threshold)
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return {}

        core = await self._load_core(user_id, fernet)
        associative = await self._load_associative(user_id, message, fernet)
        episodic = await self._load_episodic(user_id, fernet)
        proactive = await self._load_proactive(user_id, fernet)

        user_dbg = await self._get_user_debug(user_id)
        logger.info(
            "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
            trace_id or "-",
            user_id,
            user_dbg.get("tier") or "-",
            len(core),
            len(associative),
            len(episodic),
            len(proactive),
        )

        return {
            "core_memory": core,
            "associative_memory": associative,
            "episodic_memory": episodic,
            "proactive_hints": proactive,
        }

    async def store_episode(
        self,
        user_id: str,
        session_id: str,
        message: str,
        response: str,
        trace_id: str | None = None,
    ) -> None:
        """Summarise and store a completed interaction in episodic memory.

        The summary is a simple heuristic concatenation (no LLM call) to keep
        latency low: the first 200 characters of the message and of the
        response. Full LLM summarisation can be added in a later step.

        Does nothing when the user has no usable encryption key. A failed
        commit is logged and rolled back, not raised.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
        self._db.add(
            MemoryEpisodic(
                id=str(uuid.uuid4()),
                user_id=user_id,
                summary_encrypted=_encrypt(fernet, summary),
                session_id=session_id,
            )
        )
        if not await self._commit_or_rollback(
            "memory: store_episode failed user=%s: %s", user_id
        ):
            return

        # The trace log runs only after the commit succeeded, outside its
        # error handling, so a logging problem is never reported as a failed write.
        logger.info(
            "memory: store_episode trace=%s user=%s tier=%s session=%s",
            trace_id or "-",
            user_id,
            await self._tier_after_commit(user_id),
            session_id,
        )

    async def update_core(
        self, user_id: str, key: str, value: str, trace_id: str | None = None
    ) -> None:
        """Upsert a core memory key/value for a user.

        Failures (no usable key, failed commit) are logged, not raised.
        """
        await self._upsert_core(user_id, key, value, trace_id)

    async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
        """Return core memory as editable blocks (label/value), ordered by key.

        Rows whose value cannot be decrypted are omitted. Returns ``[]`` when
        the user has no usable encryption key.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []

        result = await self._db.execute(
            select(MemoryCore)
            .where(MemoryCore.user_id == user_id)
            .order_by(MemoryCore.key.asc())
        )
        out: list[dict[str, str]] = []
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                out.append({"label": row.key, "value": plaintext})
        logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
        return out

    async def get_core_block(self, user_id: str, label: str) -> str | None:
        """Return a single core memory block value by label.

        Returns None when the user has no usable key, the block does not
        exist, or its value cannot be decrypted.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return None

        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == label,
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
            return None
        value = _safe_decrypt(fernet, row.value_encrypted)
        logger.debug(
            "memory: get_core_block user=%s label=%s found=%d",
            user_id,
            label,
            1 if value is not None else 0,
        )
        return value

    async def delete_core(self, user_id: str, label: str) -> bool:
        """Delete a core memory block by label. Returns True if deleted.

        Returns False when the block does not exist or the commit fails (the
        failure is logged and rolled back).
        """
        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == label,
            )
        )
        row = result.scalar_one_or_none()
        if row is None:
            logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
            return False

        await self._db.delete(row)
        if not await self._commit_or_rollback(
            "memory: delete_core failed user=%s label=%s: %s", user_id, label
        ):
            return False
        logger.info("memory: delete_core user=%s label=%s", user_id, label)
        return True

    async def append_core(self, user_id: str, label: str, content: str) -> None:
        """Append content to a core block, creating it if missing.

        Existing content and the new content are joined with a newline. The
        success line is logged only when the write was actually committed.
        """
        current = await self.get_core_block(user_id, label)
        created = current is None
        value = content if created else f"{current}\n{content}"
        if await self._upsert_core(user_id, label, value):
            logger.info(
                "memory: append_core user=%s label=%s created=%d",
                user_id,
                label,
                1 if created else 0,
            )

    async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
        """Replace one exact string inside a core block.

        Only the first occurrence of ``old`` is replaced. Returns False if the
        block or the substring is not found, or if the write could not be
        committed; True only when the updated value was stored.
        """
        current = await self.get_core_block(user_id, label)
        if current is None or old not in current:
            logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
            return False
        if not await self._upsert_core(user_id, label, current.replace(old, new, 1)):
            logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
            return False
        logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
        return True

    async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
        """Insert a long-term archival memory entry.

        The entry is stored in the associative table with no embedding;
        ``source`` is written to the row's ``entity_type``. Does nothing when
        the user has no usable encryption key; a failed commit is logged and
        rolled back.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return

        self._db.add(
            MemoryAssociative(
                id=str(uuid.uuid4()),
                user_id=user_id,
                content_encrypted=_encrypt(fernet, content),
                embedding=None,
                entity_type=source,
                entity_id=None,
            )
        )
        if await self._commit_or_rollback(
            "memory: insert_archival failed user=%s: %s", user_id
        ):
            logger.info("memory: insert_archival user=%s source=%s", user_id, source)

    async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Search archival memory (keyword fallback; semantic ranking can replace this).

        Scans the most recently updated associative rows (at most
        ``_SEARCH_SCAN_LIMIT``) and returns up to ``max(top_k, 1)`` plaintexts
        containing the query, case-insensitively. A blank query matches every row.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []

        result = await self._db.execute(
            select(MemoryAssociative)
            .where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc())
            .limit(self._SEARCH_SCAN_LIMIT)
        )
        rows = result.scalars().all()
        # Generator: rows are decrypted lazily, so scanning stops once enough hits exist.
        hits = self._keyword_hits(
            (_safe_decrypt(fernet, row.content_encrypted) for row in rows), query, top_k
        )
        # NOTE(review): this logs up to 80 characters of the raw query at INFO —
        # confirm that is acceptable given memory content is kept encrypted at rest.
        logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(hits))
        return hits

    async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
        """Search recall memory (episodic summaries) by keyword.

        Scans the newest episodic rows (at most ``_SEARCH_SCAN_LIMIT``) and
        returns up to ``max(top_k, 1)`` summaries containing the query,
        case-insensitively. A blank query matches every row.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return []

        result = await self._db.execute(
            select(MemoryEpisodic)
            .where(MemoryEpisodic.user_id == user_id)
            .order_by(MemoryEpisodic.created_at.desc())
            .limit(self._SEARCH_SCAN_LIMIT)
        )
        rows = result.scalars().all()
        # Generator: rows are decrypted lazily, so scanning stops once enough hits exist.
        hits = self._keyword_hits(
            (_safe_decrypt(fernet, row.summary_encrypted) for row in rows), query, top_k
        )
        # NOTE(review): this logs up to 80 characters of the raw query at INFO —
        # confirm that is acceptable given memory content is kept encrypted at rest.
        logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(hits))
        return hits

    # ── Private helpers ───────────────────────────────────────────────────────

    @staticmethod
    def _keyword_hits(plaintexts: Iterable[str | None], query: str, top_k: int) -> list[str]:
        """Return the first ``max(top_k, 1)`` texts containing ``query``.

        Matching is a case-insensitive substring test on the stripped query; a
        blank query matches everything. ``None`` entries (failed decryptions)
        are skipped. Iteration stops as soon as enough hits are collected, so
        a lazy iterable is consumed only as far as needed.
        """
        needle = query.strip().lower()
        limit = max(top_k, 1)
        hits: list[str] = []
        for text in plaintexts:
            if text is None:
                continue
            if not needle or needle in text.lower():
                hits.append(text)
                if len(hits) >= limit:
                    break
        return hits

    @staticmethod
    def _decrypt_all(fernet: Fernet, ciphertexts: Iterable[str]) -> list[str]:
        """Decrypt each ciphertext in order, dropping entries that fail to decrypt."""
        decrypted = (_safe_decrypt(fernet, ciphertext) for ciphertext in ciphertexts)
        return [text for text in decrypted if text is not None]

    async def _commit_or_rollback(self, failure_msg: str, *args: object) -> bool:
        """Commit the session; on failure log, roll back, and return False.

        ``failure_msg`` is a %-style log format whose last placeholder receives
        the exception; ``args`` fill the placeholders before it.
        """
        try:
            await self._db.commit()
        except Exception as exc:
            logger.error(failure_msg, *args, exc)
            await self._db.rollback()
            return False
        return True

    async def _tier_after_commit(self, user_id: str) -> str:
        """Return the user's tier for a post-commit trace log, or "-" if unavailable.

        The write is already committed when this runs, so a failing lookup is
        logged as a warning instead of being reported as a failed write. The
        session is rolled back afterwards to reset it, as the original error
        path did.
        """
        try:
            info = await self._get_user_debug(user_id)
        except Exception as exc:
            logger.warning("memory: tier lookup failed user=%s: %s", user_id, exc)
            await self._db.rollback()
            return "-"
        return info.get("tier") or "-"

    async def _upsert_core(
        self, user_id: str, key: str, value: str, trace_id: str | None = None
    ) -> bool:
        """Insert or overwrite one core memory row. Returns True only if committed.

        Returns False when the user has no usable encryption key or the commit
        fails; both cases are logged by the helpers that detect them.
        """
        fernet = await self._get_fernet(user_id)
        if fernet is None:
            return False

        encrypted = _encrypt(fernet, value)
        result = await self._db.execute(
            select(MemoryCore).where(
                MemoryCore.user_id == user_id,
                MemoryCore.key == key,
            )
        )
        existing = result.scalar_one_or_none()
        if existing is not None:
            existing.value_encrypted = encrypted
        else:
            self._db.add(
                MemoryCore(
                    id=str(uuid.uuid4()),
                    user_id=user_id,
                    key=key,
                    value_encrypted=encrypted,
                )
            )
        if not await self._commit_or_rollback(
            "memory: update_core failed user=%s key=%s: %s", user_id, key
        ):
            return False

        logger.info(
            "memory: update_core trace=%s user=%s tier=%s key=%s",
            trace_id or "-",
            user_id,
            await self._tier_after_commit(user_id),
            key,
        )
        return True

    async def _get_fernet(self, user_id: str) -> Fernet | None:
        """Load the user's Fernet key from DB. Returns None if missing or unusable."""
        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None or not user.encryption_key:
            logger.warning("memory: no encryption_key for user=%s", user_id)
            return None
        try:
            return Fernet(user.encryption_key.encode())
        except (ValueError, TypeError) as exc:
            # A malformed stored key is treated like a missing one so callers
            # degrade the same way instead of raising mid-request.
            logger.error("memory: invalid encryption_key for user=%s: %s", user_id, exc)
            return None

    async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
        """Load lightweight user debug fields for trace logs.

        Returns ``{"tier": None}`` when the user does not exist.
        """
        result = await self._db.execute(select(User).where(User.id == user_id))
        user = result.scalar_one_or_none()
        if user is None:
            return {"tier": None}
        return {
            "tier": user.tier,
        }

    async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
        """Return every decryptable core key/value for the user as plaintext."""
        result = await self._db.execute(
            select(MemoryCore).where(MemoryCore.user_id == user_id)
        )
        out: dict[str, str] = {}
        for row in result.scalars().all():
            plaintext = _safe_decrypt(fernet, row.value_encrypted)
            if plaintext is not None:
                out[row.key] = plaintext
        return out

    async def _load_associative(
        self, user_id: str, message: str, fernet: Fernet
    ) -> list[str]:
        """Load top-k associative memories.

        Production: uses pgvector cosine similarity on the message embedding.
        Current implementation: returns the ``_ASSOCIATIVE_TOP_K`` most
        recently updated rows (no external embedding call) so tests pass
        without a live OpenAI key. ``message`` is accepted for the similarity
        search but is not used by the current query.
        """
        result = await self._db.execute(
            select(MemoryAssociative)
            .where(MemoryAssociative.user_id == user_id)
            .order_by(MemoryAssociative.updated_at.desc())
            .limit(_ASSOCIATIVE_TOP_K)
        )
        rows = result.scalars().all()
        return self._decrypt_all(fernet, (row.content_encrypted for row in rows))

    async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return the newest ``_EPISODIC_RECENT_N`` episodic summaries, newest first."""
        result = await self._db.execute(
            select(MemoryEpisodic)
            .where(MemoryEpisodic.user_id == user_id)
            .order_by(MemoryEpisodic.created_at.desc())
            .limit(_EPISODIC_RECENT_N)
        )
        rows = result.scalars().all()
        return self._decrypt_all(fernet, (row.summary_encrypted for row in rows))

    async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
        """Return proactive patterns at or above the confidence threshold, highest first."""
        result = await self._db.execute(
            select(MemoryProactive)
            .where(
                MemoryProactive.user_id == user_id,
                MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
            )
            .order_by(MemoryProactive.confidence.desc())
        )
        rows = result.scalars().all()
        return self._decrypt_all(fernet, (row.pattern_encrypted for row in rows))
|
|
|
|
|
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
|
|
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
    """Encrypt *plaintext* with *fernet* and return the resulting token as text."""
    token = fernet.encrypt(plaintext.encode())
    return token.decode()
|
|
|
|
|
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
    """Decrypt and return plaintext, or None on error (corrupted/wrong key).

    Never raises: every failure is logged as a warning and reported as None
    so callers can simply skip the affected row.
    """
    try:
        return fernet.decrypt(ciphertext.encode()).decode()
    except InvalidToken:
        # Expected failure mode. InvalidToken typically carries no message, so
        # spell the cause out instead of logging an empty string.
        logger.warning("memory: decrypt failed: invalid token (corrupted data or wrong key)")
    except Exception as exc:
        # Anything else (e.g. a non-string ciphertext); %r keeps the exception type visible.
        logger.warning("memory: decrypt failed: %r", exc)
    return None
|