"""Memory Middleware — enrich requests with memory context and store interactions. Four-tier memory model (MemGPT-style): core — persistent key/value user preferences, always injected associative — semantic similarity search via pgvector (top-k) episodic — recent session summaries (last N) proactive — behavioral patterns above confidence threshold All memory content is encrypted at rest using the per-user Fernet key stored in User.encryption_key. Decryption happens in-memory only. Usage: memory = MemoryMiddleware(db_session) context = await memory.enrich_context(user_id, message) # ... run agent ... await memory.store_episode(user_id, session_id, message, response) """ from __future__ import annotations import logging import uuid from typing import Any from cryptography.fernet import Fernet, InvalidToken from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models import ( MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User, ) logger = logging.getLogger(__name__) # Tuning constants _ASSOCIATIVE_TOP_K = 5 _EPISODIC_RECENT_N = 10 _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6 class MemoryMiddleware: """Enrich agent context with memory and persist interactions after.""" def __init__(self, db: AsyncSession) -> None: self._db = db # ── Public API ──────────────────────────────────────────────────────────── async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]: """Build memory context dict to inject into the agent before LLM call. Returns a dict with keys: core_memory — {key: plaintext_value, ...} associative_memory — [plaintext_content, ...] (top-k by keyword match) episodic_memory — [plaintext_summary, ...] (most recent N) proactive_hints — [plaintext_pattern, ...] (above threshold) """ fernet = await self._get_fernet(user_id) if fernet is None: return {} core = await self._load_core(user_id, fernet) associative = await self._load_associative(user_id, message, fernet) episodic = await self._load_episodic(user_id, fernet) proactive = await self._load_proactive(user_id, fernet) return { "core_memory": core, "associative_memory": associative, "episodic_memory": episodic, "proactive_hints": proactive, } async def store_episode( self, user_id: str, session_id: str, message: str, response: str, ) -> None: """Summarise and store a completed interaction in episodic memory. The summary is a simple heuristic concatenation (no LLM call) to keep latency low. Full LLM summarisation can be added in a later step. """ fernet = await self._get_fernet(user_id) if fernet is None: return summary = f"User: {message[:200]}\nAssistant: {response[:200]}" encrypted = _encrypt(fernet, summary) row = MemoryEpisodic( id=str(uuid.uuid4()), user_id=user_id, summary_encrypted=encrypted, session_id=session_id, ) self._db.add(row) try: await self._db.commit() except Exception as exc: logger.error("memory: store_episode failed user=%s: %s", user_id, exc) await self._db.rollback() async def update_core(self, user_id: str, key: str, value: str) -> None: """Upsert a core memory key/value for a user.""" fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, value) result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == key, ) ) existing = result.scalar_one_or_none() if existing is not None: existing.value_encrypted = encrypted else: self._db.add(MemoryCore( id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted, )) try: await self._db.commit() except Exception as exc: logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc) await self._db.rollback() # ── Private helpers ─────────────────────────────────────────────────────── async def _get_fernet(self, user_id: str) -> Fernet | None: """Load the user's Fernet key from DB. Returns None if missing.""" result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None or not user.encryption_key: logger.warning("memory: no encryption_key for user=%s", user_id) return None return Fernet(user.encryption_key.encode()) async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id) ) rows = result.scalars().all() out: dict[str, str] = {} for row in rows: plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out[row.key] = plaintext return out async def _load_associative( self, user_id: str, message: str, fernet: Fernet ) -> list[str]: """Load top-k associative memories. Production: uses pgvector cosine similarity on the message embedding. Current implementation: keyword-based fallback (no external embedding call) so tests pass without a live OpenAI key. """ result = await self._db.execute( select(MemoryAssociative) .where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()) .limit(_ASSOCIATIVE_TOP_K) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryEpisodic) .where(MemoryEpisodic.user_id == user_id) .order_by(MemoryEpisodic.created_at.desc()) .limit(_EPISODIC_RECENT_N) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryProactive) .where( MemoryProactive.user_id == user_id, MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD, ) .order_by(MemoryProactive.confidence.desc()) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.pattern_encrypted) if plaintext is not None: out.append(plaintext) return out # ── Encryption helpers ──────────────────────────────────────────────────────── def _encrypt(fernet: Fernet, plaintext: str) -> str: return fernet.encrypt(plaintext.encode()).decode() def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None: """Decrypt and return plaintext, or None on error (corrupted/wrong key).""" try: return fernet.decrypt(ciphertext.encode()).decode() except (InvalidToken, Exception) as exc: logger.warning("memory: decrypt failed: %s", exc) return None