"""Memory Middleware — adapted for Chat Service. Uses shared.models instead of app.models. Otherwise identical to the monolith's app/core/memory_middleware.py. """ from __future__ import annotations import logging import uuid from typing import Any from cryptography.fernet import Fernet, InvalidToken from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from shared.models import ( MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User, ) logger = logging.getLogger(__name__) _ASSOCIATIVE_TOP_K = 5 _EPISODIC_RECENT_N = 10 _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6 class MemoryMiddleware: def __init__(self, db: AsyncSession) -> None: self._db = db async def enrich_context( self, user_id: str, message: str, trace_id: str | None = None, session_id: str | None = None, ) -> dict[str, Any]: fernet = await self._get_fernet(user_id) if fernet is None: return {} core = await self._load_core(user_id, fernet) associative = await self._load_associative(user_id, message, fernet) episodic = await self._load_episodic(user_id, fernet, session_id=session_id) proactive = await self._load_proactive(user_id, fernet) logger.info( "memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d", trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive), ) return { "core_memory": core, "associative_memory": associative, "episodic_memory": episodic, "proactive_hints": proactive, } async def store_episode( self, user_id: str, session_id: str, message: str, response: str, trace_id: str | None = None, ) -> None: fernet = await self._get_fernet(user_id) if fernet is None: return summary = f"User: {message[:200]}\nAssistant: {response[:200]}" encrypted = _encrypt(fernet, summary) row = MemoryEpisodic( id=str(uuid.uuid4()), user_id=user_id, summary_encrypted=encrypted, session_id=session_id, ) self._db.add(row) try: await self._db.commit() except Exception as exc: logger.error("memory: store_episode failed user=%s: %s", user_id, exc) await self._db.rollback() async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, value) result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key) ) existing = result.scalar_one_or_none() if existing is not None: existing.value_encrypted = encrypted else: self._db.add(MemoryCore( id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted, )) try: await self._db.commit() except Exception as exc: logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc) await self._db.rollback() async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]: fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc()) ) out: list[dict[str, str]] = [] for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out.append({"label": row.key, "value": plaintext}) return out async def get_core_block(self, user_id: str, label: str) -> str | None: fernet = await self._get_fernet(user_id) if fernet is None: return None result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label) ) row = result.scalar_one_or_none() if row is None: return None return _safe_decrypt(fernet, row.value_encrypted) async def delete_core(self, user_id: str, label: str) -> bool: result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label) ) row = result.scalar_one_or_none() if row is None: return False await self._db.delete(row) try: await self._db.commit() return True except Exception as exc: logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc) await self._db.rollback() return False async def append_core(self, user_id: str, label: str, content: str) -> None: current = await self.get_core_block(user_id, label) if current is None: await self.update_core(user_id, label, content) return await self.update_core(user_id, label, f"{current}\n{content}") async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool: current = await self.get_core_block(user_id, label) if current is None or old not in current: return False await self.update_core(user_id, label, current.replace(old, new, 1)) return True async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None: fernet = await self._get_fernet(user_id) if fernet is None: return encrypted = _encrypt(fernet, content) row = MemoryAssociative( id=str(uuid.uuid4()), user_id=user_id, content_encrypted=encrypted, embedding=None, entity_type=source, entity_id=None, ) self._db.add(row) try: await self._db.commit() except Exception as exc: logger.error("memory: insert_archival failed user=%s: %s", user_id, exc) await self._db.rollback() async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]: fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryAssociative).where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()).limit(100) ) needle = query.strip().lower() out: list[str] = [] for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is None: continue if not needle or needle in plaintext.lower(): out.append(plaintext) if len(out) >= max(top_k, 1): break return out async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]: fernet = await self._get_fernet(user_id) if fernet is None: return [] result = await self._db.execute( select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id) .order_by(MemoryEpisodic.created_at.desc()).limit(100) ) needle = query.strip().lower() out: list[str] = [] for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is None: continue if not needle or needle in plaintext.lower(): out.append(plaintext) if len(out) >= max(top_k, 1): break return out # ── Private ─────────────────────────────────────────────────────── async def _get_fernet(self, user_id: str) -> Fernet | None: result = await self._db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None or not user.encryption_key: logger.warning("memory: no encryption_key for user=%s", user_id) return None return Fernet(user.encryption_key.encode()) async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: result = await self._db.execute( select(MemoryCore).where(MemoryCore.user_id == user_id) ) out: dict[str, str] = {} for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.value_encrypted) if plaintext is not None: out[row.key] = plaintext return out async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryAssociative).where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K) ) out: list[str] = [] for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.content_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]: query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id) if session_id: query = query.where(MemoryEpisodic.session_id == session_id) result = await self._db.execute( query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N) ) out: list[str] = [] for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.summary_encrypted) if plaintext is not None: out.append(plaintext) return out async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: result = await self._db.execute( select(MemoryProactive).where( MemoryProactive.user_id == user_id, MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD, ).order_by(MemoryProactive.confidence.desc()) ) out: list[str] = [] for row in result.scalars().all(): plaintext = _safe_decrypt(fernet, row.pattern_encrypted) if plaintext is not None: out.append(plaintext) return out def _encrypt(fernet: Fernet, plaintext: str) -> str: return fernet.encrypt(plaintext.encode()).decode() def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None: try: return fernet.decrypt(ciphertext.encode()).decode() except (InvalidToken, Exception) as exc: logger.warning("memory: decrypt failed: %s", exc) return None