"""Memory Middleware — enrich requests with memory context and store interactions. Four-tier memory model (MemGPT-style): core — persistent key/value user preferences, always injected associative — semantic similarity search via pgvector (top-k) episodic — recent session summaries (last N) proactive — behavioral patterns above confidence threshold All memory content is encrypted at rest using the per-user Fernet key stored in User.encryption_key. Decryption happens in-memory only. Usage: memory = MemoryMiddleware(db_session) context = await memory.enrich_context(user_id, message) # ... run agent ... await memory.store_episode(user_id, session_id, message, response) """ from __future__ import annotations import logging import uuid from typing import Any from cryptography.fernet import Fernet, InvalidToken from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models import ( MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User, ) logger = logging.getLogger(__name__) # Tuning constants _ASSOCIATIVE_TOP_K = 5 _EPISODIC_RECENT_N = 10 _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6 class MemoryMiddleware: """Enrich orchestrator context with memory and persist interactions after.""" def __init__(self, db: AsyncSession) -> None: self._db = db # ── Public API ──────────────────────────────────────────────────────────── async def enrich_context( self, user_id: str, message: str, trace_id: str | None = None, session_id: str | None = None, ) -> dict[str, Any]: """Build memory context dict to inject into the orchestrator before LLM call. Returns a dict with keys: core_memory — {key: plaintext_value, ...} associative_memory — [plaintext_content, ...] (top-k by keyword match) episodic_memory — [plaintext_summary, ...] (most recent N) proactive_hints — [plaintext_pattern, ...] (above threshold) """ fernet = await self._get_fernet(user_id) if fernet is None: return {} user_dbg = await self._get_user_debug(user_id) user_tier: str = user_dbg.get("tier") or "free" core = await self._load_core(user_id, fernet) associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier) episodic = await self._load_episodic(user_id, fernet, session_id=session_id) proactive = await self._load_proactive(user_id, fernet) logger.info( "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d", trace_id or "-", user_id, user_tier, len(core), len(associative), len(episodic), len(proactive), ) return { "core_memory": core, "associative_memory": associative, "episodic_memory": episodic, "proactive_hints": proactive, } async def store_episode( self, user_id: str, session_id: str, message: str, response: str, trace_id: str | None = None, ) -> None: """Summarise and store a completed interaction in episodic memory. The summary is a simple heuristic concatenation (no LLM call) to keep latency low. Full LLM summarisation can be added in a later step. """ fernet = await self._get_fernet(user_id) if fernet is None: return summary = f"User: {message[:200]}\nAssistant: {response[:200]}" encrypted = _encrypt(fernet, summary) row = MemoryEpisodic( id=str(uuid.uuid4()), user_id=user_id, summary_encrypted=encrypted, session_id=session_id, ) self._db.add(row) try: await self._db.commit() user_dbg = await self._get_user_debug(user_id) logger.info( "memory: store_episode trace=%s user=%s tier=%s session=%s", trace_id or "-", user_id, user_dbg.get("tier") or "-", session_id, ) except Exception as exc: logger.error("memory: store_episode failed user=%s: %s", user_id, exc) await self._db.rollback() async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: """Upsert a core memory key/value for a user.""" fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, value) result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == key, ) ) existing = result.scalar_one_or_none() if existing is not None: existing.value_encrypted = encrypted else: self._db.add(MemoryCore( id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted, )) try: await self._db.commit() user_dbg = await self._get_user_debug(user_id) logger.info( "memory: update_core trace=%s user=%s tier=%s key=%s", trace_id or "-", user_id, user_dbg.get("tier") or "-", key, ) except Exception as exc: logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc) await self._db.rollback() async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]: """Return core memory as editable blocks (label/value).""" fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryCore) .where(MemoryCore.user_id == user_id) .order_by(MemoryCore.key.asc()) ) rows = result.scalars().all() out: list[dict[str, str]] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out.append({"label": row.key, "value": plaintext}) logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out)) return out async def get_core_block(self, user_id: str, label: str) -> str | None: """Return a single core memory block value by label.""" fernet = await self._get_fernet(user_id) if fernet is None: return None result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == label, ) ) row = result.scalar_one_or_none() if row is None: logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label) return None value = _safe_decrypt(fernet, row.value_encrypted) logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0) return value async def delete_core(self, user_id: str, label: str) -> bool: """Delete a core memory block by label. Returns True if deleted.""" result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == label, ) ) row = result.scalar_one_or_none() if row is None: logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label) return False await self._db.delete(row) try: await self._db.commit() logger.info("memory: delete_core user=%s label=%s", user_id, label) return True except Exception as exc: logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc) await self._db.rollback() return False async def append_core(self, user_id: str, label: str, content: str) -> None: """Append content to a core block, creating it if missing.""" current = await self.get_core_block(user_id, label) if current is None: await self.update_core(user_id, label, content) logger.info("memory: append_core user=%s label=%s created=1", user_id, label) return await self.update_core(user_id, label, f"{current}\n{content}") logger.info("memory: append_core user=%s label=%s created=0", user_id, label) async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool: """Replace one exact string inside a core block. Returns False if not found.""" current = await self.get_core_block(user_id, label) if current is None or old not in current: logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label) return False await self.update_core(user_id, label, current.replace(old, new, 1)) logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label) return True async def store_associative( self, user_id: str, content: str, entity_type: str | None = None, entity_id: str | None = None, ) -> None: """Store associative memory; embed if user tier has real_embeddings.""" from app.billing.tier_manager import tier_manager # noqa: PLC0415 from app.core.embeddings import embed_text # noqa: PLC0415 fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, content) user_dbg = await self._get_user_debug(user_id) user_tier = user_dbg.get("tier") or "free" embedding: list[float] | None = None if tier_manager.check_feature(user_tier, "real_embeddings"): embedding = await embed_text(content) row = MemoryAssociative( id=str(uuid.uuid4()), user_id=user_id, content_encrypted=encrypted, embedding=embedding, entity_type=entity_type, entity_id=entity_id, ) self._db.add(row) try: await self._db.commit() logger.info( "memory: store_associative user=%s embedded=%s", user_id, embedding is not None, ) except Exception as exc: logger.error("memory: store_associative failed user=%s: %s", user_id, exc) await self._db.rollback() async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None: """Insert a long-term archival memory entry.""" fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, content) row = MemoryAssociative( id=str(uuid.uuid4()), user_id=user_id, content_encrypted=encrypted, embedding=None, entity_type=source, entity_id=None, ) self._db.add(row) try: await self._db.commit() logger.info("memory: insert_archival user=%s source=%s", user_id, source) except Exception as exc: logger.error("memory: insert_archival failed user=%s: %s", user_id, exc) await self._db.rollback() async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]: """Search archival memory (keyword fallback; semantic ranking can replace this).""" fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryAssociative) .where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()) .limit(100) ) rows = result.scalars().all() needle = query.strip().lower() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is None: continue if not needle or needle in plaintext.lower(): out.append(plaintext) if len(out) >= max(top_k, 1): break logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out)) return out async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]: """Search recall memory (episodic summaries) by keyword.""" fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryEpisodic) .where(MemoryEpisodic.user_id == user_id) .order_by(MemoryEpisodic.created_at.desc()) .limit(100) ) rows = result.scalars().all() needle = query.strip().lower() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is None: continue if not needle or needle in plaintext.lower(): out.append(plaintext) if len(out) >= max(top_k, 1): break logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out)) return out # ── Private helpers ─────────────────────────────────────────────────────── async def _get_fernet(self, user_id: str) -> Fernet | None: """Load the user's Fernet key from DB. Returns None if missing.""" result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None or not user.encryption_key: logger.warning("memory: no encryption_key for user=%s", user_id) return None return Fernet(user.encryption_key.encode()) async def _get_user_debug(self, user_id: str) -> dict[str, str | None]: """Load lightweight user debug fields for trace logs.""" result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None: return {"tier": None} return { "tier": user.tier, } async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id) ) rows = result.scalars().all() out: dict[str, str] = {} for row in rows: plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out[row.key] = plaintext return out async def _load_associative( self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free" ) -> list[str]: """Load top-k associative memories. Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature). Free / embedding failure: keyword-ordered fallback (most recent rows). """ from app.billing.tier_manager import tier_manager # noqa: PLC0415 from app.core.embeddings import embed_text # noqa: PLC0415 if tier_manager.check_feature(user_tier, "real_embeddings"): vec = await embed_text(message) if vec is not None: try: result = await self._db.execute( select(MemoryAssociative) .where( MemoryAssociative.user_id == user_id, MemoryAssociative.embedding.isnot(None), ) .order_by(MemoryAssociative.embedding.cosine_distance(vec)) .limit(_ASSOCIATIVE_TOP_K) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is not None: out.append(plaintext) logger.info( "memory: _load_associative user=%s mode=vector hits=%d", user_id, len(out), ) return out except Exception as exc: logger.warning( "memory: vector search failed user=%s, falling back to keyword: %s", user_id, exc, ) # Keyword fallback: most recent rows result = await self._db.execute( select(MemoryAssociative) .where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()) .limit(_ASSOCIATIVE_TOP_K) ) rows = result.scalars().all() out = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_episodic( self, user_id: str, fernet: Fernet, session_id: str | None = None, ) -> list[str]: query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id) if session_id: query = query.where(MemoryEpisodic.session_id == session_id) result = await self._db.execute( query .order_by(MemoryEpisodic.created_at.desc()) .limit(_EPISODIC_RECENT_N) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryProactive) .where( MemoryProactive.user_id == user_id, MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD, ) .order_by(MemoryProactive.confidence.desc()) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.pattern_encrypted) if plaintext is not None: out.append(plaintext) return out # ── Encryption helpers ──────────────────────────────────────────────────────── def _encrypt(fernet: Fernet, plaintext: str) -> str: return fernet.encrypt(plaintext.encode()).decode() def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None: """Decrypt and return plaintext, or None on error (corrupted/wrong key).""" try: return fernet.decrypt(ciphertext.encode()).decode() except (InvalidToken, Exception) as exc: logger.warning("memory: decrypt failed: %s", exc) return None