"""Memory Middleware — enrich requests with memory context and store interactions. Four-tier memory model (MemGPT-style): core — persistent key/value user preferences, always injected associative — semantic similarity search via pgvector (top-k) episodic — recent session summaries (last N) proactive — behavioral patterns above confidence threshold All memory content is encrypted at rest using the per-user Fernet key stored in User.encryption_key. Decryption happens in-memory only. Usage: memory = MemoryMiddleware(db_session) context = await memory.enrich_context(user_id, message) # ... run agent ... await memory.store_episode(user_id, session_id, message, response) """ from __future__ import annotations import asyncio import logging import uuid from datetime import datetime, timezone from typing import Any from cryptography.fernet import Fernet, InvalidToken from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models import ( ExtractionQueue, MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, MemoryRelation, User, ) logger = logging.getLogger(__name__) def _now() -> datetime: return datetime.now(timezone.utc) # Tuning constants _ASSOCIATIVE_TOP_K = 5 _EPISODIC_RECENT_N = 10 _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6 class MemoryMiddleware: """Enrich orchestrator context with memory and persist interactions after.""" def __init__(self, db: AsyncSession) -> None: self._db = db # ── Public API ──────────────────────────────────────────────────────────── async def enrich_context( self, user_id: str, message: str, trace_id: str | None = None, session_id: str | None = None, ) -> dict[str, Any]: """Build memory context dict to inject into the orchestrator before LLM call. Returns a dict with keys: core_memory — {key: plaintext_value, ...} associative_memory — [plaintext_content, ...] (top-k by keyword match) episodic_memory — [plaintext_summary, ...] (most recent N) proactive_hints — [plaintext_pattern, ...] (above threshold) relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+) """ fernet = await self._get_fernet(user_id) if fernet is None: return {} user_dbg = await self._get_user_debug(user_id) user_tier: str = user_dbg.get("tier") or "free" core = await self._load_core(user_id, fernet) associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier) episodic = await self._load_episodic(user_id, fernet, session_id=session_id) proactive = await self._load_proactive(user_id, fernet) relational = await self._load_relational(user_id, user_tier=user_tier) logger.info( "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d", trace_id or "-", user_id, user_tier, len(core), len(associative), len(episodic), len(proactive), len(relational), ) return { "core_memory": core, "associative_memory": associative, "episodic_memory": episodic, "proactive_hints": proactive, "relational_memory": relational, } async def store_episode( self, user_id: str, session_id: str, message: str, response: str, trace_id: str | None = None, ) -> None: """Summarise and store a completed interaction in episodic memory. The summary is a simple heuristic concatenation (no LLM call) to keep latency low. After committing the episode row, dispatches the Mem0-style extraction pipeline: - Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime). - Free → enqueue an ExtractionQueue row for the daily cron. """ fernet = await self._get_fernet(user_id) if fernet is None: return summary = f"User: {message[:200]}\nAssistant: {response[:200]}" encrypted = _encrypt(fernet, summary) episode = MemoryEpisodic( id=str(uuid.uuid4()), user_id=user_id, summary_encrypted=encrypted, session_id=session_id, ) self._db.add(episode) episode_id: str = episode.id try: await self._db.commit() user_dbg = await self._get_user_debug(user_id) tier = user_dbg.get("tier") or "free" logger.info( "memory: store_episode trace=%s user=%s tier=%s session=%s", trace_id or "-", user_id, tier, session_id, ) except Exception as exc: logger.error("memory: store_episode failed user=%s: %s", user_id, exc) await self._db.rollback() return # ── Dispatch extraction pipeline (Phase 2) ──────────────────────────── await self._dispatch_extraction( user_id=user_id, episode_id=episode_id, last_user_msg=message, last_assistant_msg=response, session_id=session_id, ) async def _dispatch_extraction( self, user_id: str, episode_id: str, last_user_msg: str, last_assistant_msg: str, session_id: str | None, ) -> None: """Route extraction to realtime task or batch queue based on user tier.""" from app.billing.tier_manager import tier_manager # noqa: PLC0415 tier = await tier_manager.get_tier(user_id, self._db) if tier_manager.check_feature(tier, "realtime_extraction"): # Pro/Power/Team: fire-and-forget in the background. # Must open a fresh session — request session closes after handler returns. from app.core.memory_extraction import run_extraction # noqa: PLC0415 from app.db import async_session # noqa: PLC0415 async def _task() -> None: try: async with async_session() as fresh_db: await run_extraction( db=fresh_db, user_id=user_id, last_user_msg=last_user_msg, last_assistant_msg=last_assistant_msg, session_id=session_id, ) except Exception as exc: logger.warning( "memory: extraction task failed user=%s: %s", user_id, exc ) asyncio.create_task(_task()) logger.info("memory: realtime extraction dispatched user=%s", user_id) else: # Free tier: enqueue for daily batch cron. queue_row = ExtractionQueue( id=str(uuid.uuid4()), user_id=user_id, episode_id=episode_id, ) self._db.add(queue_row) try: await self._db.commit() logger.info( "memory: extraction enqueued (batch) user=%s episode=%s", user_id, episode_id, ) except Exception as exc: logger.warning( "memory: extraction queue insert failed user=%s: %s", user_id, exc ) await self._db.rollback() async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: """Upsert a core memory key/value for a user.""" fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, value) result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == key, ) ) existing = result.scalar_one_or_none() if existing is not None: existing.value_encrypted = encrypted else: self._db.add(MemoryCore( id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted, )) try: await self._db.commit() user_dbg = await self._get_user_debug(user_id) logger.info( "memory: update_core trace=%s user=%s tier=%s key=%s", trace_id or "-", user_id, user_dbg.get("tier") or "-", key, ) except Exception as exc: logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc) await self._db.rollback() async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]: """Return core memory as editable blocks (label/value).""" fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryCore) .where(MemoryCore.user_id == user_id) .order_by(MemoryCore.key.asc()) ) rows = result.scalars().all() out: list[dict[str, str]] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out.append({"label": row.key, "value": plaintext}) logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out)) return out async def get_core_block(self, user_id: str, label: str) -> str | None: """Return a single core memory block value by label.""" fernet = await self._get_fernet(user_id) if fernet is None: return None result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == label, ) ) row = result.scalar_one_or_none() if row is None: logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label) return None value = _safe_decrypt(fernet, row.value_encrypted) logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0) return value async def delete_core(self, user_id: str, label: str) -> bool: """Delete a core memory block by label. Returns True if deleted.""" result = await self._db.execute( select(MemoryCore).where( MemoryCore.user_id == user_id, MemoryCore.key == label, ) ) row = result.scalar_one_or_none() if row is None: logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label) return False await self._db.delete(row) try: await self._db.commit() logger.info("memory: delete_core user=%s label=%s", user_id, label) return True except Exception as exc: logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc) await self._db.rollback() return False async def append_core(self, user_id: str, label: str, content: str) -> None: """Append content to a core block, creating it if missing.""" current = await self.get_core_block(user_id, label) if current is None: await self.update_core(user_id, label, content) logger.info("memory: append_core user=%s label=%s created=1", user_id, label) return await self.update_core(user_id, label, f"{current}\n{content}") logger.info("memory: append_core user=%s label=%s created=0", user_id, label) async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool: """Replace one exact string inside a core block. Returns False if not found.""" current = await self.get_core_block(user_id, label) if current is None or old not in current: logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label) return False await self.update_core(user_id, label, current.replace(old, new, 1)) logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label) return True async def store_associative( self, user_id: str, content: str, entity_type: str | None = None, entity_id: str | None = None, ) -> None: """Store associative memory; embed if user tier has real_embeddings.""" from app.billing.tier_manager import tier_manager # noqa: PLC0415 from app.core.embeddings import embed_text # noqa: PLC0415 fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, content) user_dbg = await self._get_user_debug(user_id) user_tier = user_dbg.get("tier") or "free" embedding: list[float] | None = None if tier_manager.check_feature(user_tier, "real_embeddings"): embedding = await embed_text(content) row = MemoryAssociative( id=str(uuid.uuid4()), user_id=user_id, content_encrypted=encrypted, embedding=embedding, entity_type=entity_type, entity_id=entity_id, ) self._db.add(row) try: await self._db.commit() logger.info( "memory: store_associative user=%s embedded=%s", user_id, embedding is not None, ) except Exception as exc: logger.error("memory: store_associative failed user=%s: %s", user_id, exc) await self._db.rollback() async def upsert_relation( self, user_id: str, subject: str, subject_type: str, predicate: str, object_: str, object_type: str, *, confidence: float = 0.7, source_episode_id: str | None = None, notes: str | None = None, ) -> None: """Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label). subject_label / object_label are plaintext entity identifiers — not encrypted. notes is optional; encrypted with user Fernet if provided. """ from app.billing.tier_manager import tier_manager # noqa: PLC0415 user_dbg = await self._get_user_debug(user_id) user_tier = user_dbg.get("tier") or "free" if not tier_manager.check_feature(user_tier, "relational_memory"): logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier) return notes_encrypted: bytes | None = None if notes: fernet = await self._get_fernet(user_id) if fernet: notes_encrypted = fernet.encrypt(notes.encode()) result = await self._db.execute( select(MemoryRelation).where( MemoryRelation.user_id == user_id, MemoryRelation.subject_label == subject, MemoryRelation.predicate == predicate, MemoryRelation.object_label == object_, ) ) existing = result.scalar_one_or_none() if existing is not None: existing.subject_type = subject_type existing.object_type = object_type existing.confidence = confidence existing.last_confirmed_at = _now() if notes_encrypted is not None: existing.notes_encrypted = notes_encrypted else: self._db.add(MemoryRelation( id=str(uuid.uuid4()), user_id=user_id, subject_label=subject, subject_type=subject_type, predicate=predicate, object_label=object_, object_type=object_type, confidence=confidence, source_episode_id=source_episode_id, notes_encrypted=notes_encrypted, )) try: await self._db.commit() logger.info( "memory: upsert_relation user=%s subject=%s predicate=%s object=%s", user_id, subject, predicate, object_, ) except Exception as exc: logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc) await self._db.rollback() async def query_relations( self, user_id: str, subject: str | None = None, predicate: str | None = None, object_: str | None = None, limit: int = 20, ) -> list[MemoryRelation]: """Query relation rows for a user with optional filters.""" q = select(MemoryRelation).where(MemoryRelation.user_id == user_id) if subject is not None: q = q.where(MemoryRelation.subject_label == subject) if predicate is not None: q = q.where(MemoryRelation.predicate == predicate) if object_ is not None: q = q.where(MemoryRelation.object_label == object_) q = q.order_by(MemoryRelation.confidence.desc()).limit(limit) result = await self._db.execute(q) return list(result.scalars().all()) async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None: """Insert a long-term archival memory entry.""" fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, content) row = MemoryAssociative( id=str(uuid.uuid4()), user_id=user_id, content_encrypted=encrypted, embedding=None, entity_type=source, entity_id=None, ) self._db.add(row) try: await self._db.commit() logger.info("memory: insert_archival user=%s source=%s", user_id, source) except Exception as exc: logger.error("memory: insert_archival failed user=%s: %s", user_id, exc) await self._db.rollback() async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]: """Search archival memory (keyword fallback; semantic ranking can replace this).""" fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryAssociative) .where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()) .limit(100) ) rows = result.scalars().all() needle = query.strip().lower() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is None: continue if not needle or needle in plaintext.lower(): out.append(plaintext) if len(out) >= max(top_k, 1): break logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out)) return out async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]: """Search recall memory (episodic summaries) by keyword.""" fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryEpisodic) .where(MemoryEpisodic.user_id == user_id) .order_by(MemoryEpisodic.created_at.desc()) .limit(100) ) rows = result.scalars().all() needle = query.strip().lower() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is None: continue if not needle or needle in plaintext.lower(): out.append(plaintext) if len(out) >= max(top_k, 1): break logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out)) return out # ── Private helpers ─────────────────────────────────────────────────────── async def _get_fernet(self, user_id: str) -> Fernet | None: """Load the user's Fernet key from DB. Returns None if missing.""" result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None or not user.encryption_key: logger.warning("memory: no encryption_key for user=%s", user_id) return None return Fernet(user.encryption_key.encode()) async def _get_user_debug(self, user_id: str) -> dict[str, str | None]: """Load lightweight user debug fields for trace logs.""" from app.config.settings import settings # noqa: PLC0415 from app.models import Subscription # noqa: PLC0415 result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None: return {"tier": None} sub_result = await self._db.execute( select(Subscription.tier).where(Subscription.user_id == user_id) ) sub_tier: str | None = sub_result.scalar_one_or_none() if sub_tier: tier = sub_tier elif settings.ENV == "dev": tier = "power" else: tier = user.tier or "free" return {"tier": tier} async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id) ) rows = result.scalars().all() out: dict[str, str] = {} for row in rows: plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out[row.key] = plaintext return out async def _load_associative( self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free" ) -> list[str]: """Load top-k associative memories. Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature). Free / embedding failure: keyword-ordered fallback (most recent rows). """ from app.billing.tier_manager import tier_manager # noqa: PLC0415 from app.core.embeddings import embed_text # noqa: PLC0415 if tier_manager.check_feature(user_tier, "real_embeddings"): vec = await embed_text(message) if vec is not None: try: result = await self._db.execute( select(MemoryAssociative) .where( MemoryAssociative.user_id == user_id, MemoryAssociative.embedding.isnot(None), ) .order_by(MemoryAssociative.embedding.cosine_distance(vec)) .limit(_ASSOCIATIVE_TOP_K) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is not None: out.append(plaintext) logger.info( "memory: _load_associative user=%s mode=vector hits=%d", user_id, len(out), ) return out except Exception as exc: logger.warning( "memory: vector search failed user=%s, falling back to keyword: %s", user_id, exc, ) # Keyword fallback: most recent rows result = await self._db.execute( select(MemoryAssociative) .where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()) .limit(_ASSOCIATIVE_TOP_K) ) rows = result.scalars().all() out = [] for row in rows: plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_episodic( self, user_id: str, fernet: Fernet, session_id: str | None = None, ) -> list[str]: query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id) if session_id: query = query.where(MemoryEpisodic.session_id == session_id) result = await self._db.execute( query .order_by(MemoryEpisodic.created_at.desc()) .limit(_EPISODIC_RECENT_N) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]: """Return top-10 relation strings for Pro+ users; empty list for Free.""" from app.billing.tier_manager import tier_manager # noqa: PLC0415 if not tier_manager.check_feature(user_tier, "relational_memory"): return [] result = await self._db.execute( select(MemoryRelation) .where(MemoryRelation.user_id == user_id) .order_by(MemoryRelation.confidence.desc()) .limit(10) ) rows = result.scalars().all() out = [ f"{r.subject_label} --{r.predicate}--> {r.object_label}" for r in rows ] return out async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryProactive) .where( MemoryProactive.user_id == user_id, MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD, ) .order_by(MemoryProactive.confidence.desc()) ) rows = result.scalars().all() out: list[str] = [] for row in rows: plaintext = _safe_decrypt(fernet, row.pattern_encrypted) if plaintext is not None: out.append(plaintext) return out # ── Encryption helpers ──────────────────────────────────────────────────────── def _encrypt(fernet: Fernet, plaintext: str) -> str: return fernet.encrypt(plaintext.encode()).decode() def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None: """Decrypt and return plaintext, or None on error (corrupted/wrong key).""" try: return fernet.decrypt(ciphertext.encode()).decode() except (InvalidToken, Exception) as exc: logger.warning("memory: decrypt failed: %s", exc) return None