"""Memory maintenance jobs — Phase 3/5. Three entrypoints called by the scheduler (APScheduler) registered in app/main.py: drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5). mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5). decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3). All are safe to call manually or from tests; they never raise. """ from __future__ import annotations import json import logging import uuid from datetime import datetime, timedelta, timezone from cryptography.fernet import Fernet from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User logger = logging.getLogger(__name__) # Decay parameters for relations _DECAY_FACTOR = 0.95 _DECAY_PERIOD_DAYS = 30 _PRUNE_THRESHOLD = 0.2 # Proactive pattern decay: 10 % per 7 days since last sighting _PROACTIVE_DECAY_FACTOR = 0.9 _PROACTIVE_DECAY_PERIOD_DAYS = 7 _PROACTIVE_PRUNE_THRESHOLD = 0.2 # Mining: require at least this many episodes to attempt pattern extraction _MIN_EPISODES_FOR_MINING = 3 _MINING_LOOKBACK_DAYS = 30 # Audit: caps to control token cost _AUDIT_MAX_FACTS = 50 _AUDIT_MAX_LABELS = 100 async def decay_relations(db: AsyncSession, user_id: str) -> None: """Apply confidence decay to all relation rows for a user. Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at. Rows whose confidence falls below 0.2 are deleted. Never raises — wraps in try/except. """ try: await _decay_relations_inner(db, user_id) except Exception as exc: logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc) async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None: result = await db.execute( select(MemoryRelation).where(MemoryRelation.user_id == user_id) ) rows = result.scalars().all() now = datetime.now(timezone.utc) deleted = 0 decayed = 0 for row in rows: reference = row.last_confirmed_at or row.created_at if reference is None: continue if reference.tzinfo is None: reference = reference.replace(tzinfo=timezone.utc) days_elapsed = (now - reference).days if days_elapsed < _DECAY_PERIOD_DAYS: continue periods = days_elapsed // _DECAY_PERIOD_DAYS new_confidence = row.confidence * (_DECAY_FACTOR ** periods) if new_confidence < _PRUNE_THRESHOLD: await db.delete(row) deleted += 1 logger.info( "memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s " "confidence=%.3f (below threshold)", row.id, user_id, row.subject_label, row.predicate, new_confidence, ) else: row.confidence = new_confidence decayed += 1 try: await db.commit() logger.info( "memory_maintenance: decay_relations user=%s decayed=%d deleted=%d", user_id, decayed, deleted, ) except Exception as exc: logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc) await db.rollback() async def drain_extraction_queue(db: AsyncSession) -> None: """Process pending ExtractionQueue rows for Free-tier users. Each row corresponds to a stored episode that should be fed through the Mem0-style extraction pipeline. Rows are deleted after successful processing. Never raises — wraps in try/except. """ try: await _drain_extraction_queue_inner(db) except Exception as exc: logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc) async def _drain_extraction_queue_inner(db: AsyncSession) -> None: from app.models import ExtractionQueue # noqa: PLC0415 result = await db.execute(select(ExtractionQueue)) rows = result.scalars().all() if not rows: logger.debug("memory_maintenance: drain_extraction_queue nothing to drain") return logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows)) from app.core.memory_extraction import run_extraction # noqa: PLC0415 processed = 0 for row in rows: try: await run_extraction( db=db, user_id=row.user_id, last_user_msg="", last_assistant_msg="", session_id=None, ) await db.delete(row) await db.commit() processed += 1 except Exception as exc: logger.warning( "memory_maintenance: drain failed row=%s user=%s: %s", row.id, row.user_id, exc, ) await db.rollback() logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows)) async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None: """Mine recurring behavioral patterns from last 30 days of episodes (Power+ only). Steps: 1. Gate on proactive_mining tier feature. 2. Load + decrypt last 30 days of episodic summaries. 3. Call gpt-4o-mini to identify recurring patterns. 4. Encrypt and store each pattern in memory_proactive. 5. Apply decay to existing proactive rows. Never raises — wraps in try/except. """ try: await _mine_proactive_patterns_inner(db, user_id) except Exception as exc: logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc) async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None: from app.billing.tier_manager import tier_manager # noqa: PLC0415 tier = await tier_manager.get_tier(user_id, db) if not tier_manager.check_feature(tier, "proactive_mining"): logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier) return # Load user Fernet key result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None or not user.encryption_key: logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id) return fernet = Fernet(user.encryption_key.encode()) cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS) episodes_result = await db.execute( select(MemoryEpisodic) .where( MemoryEpisodic.user_id == user_id, MemoryEpisodic.created_at >= cutoff, ) .order_by(MemoryEpisodic.created_at.asc()) ) episode_rows = episodes_result.scalars().all() if len(episode_rows) < _MIN_EPISODES_FOR_MINING: logger.info( "memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)", user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING, ) return summaries: list[str] = [] for ep in episode_rows: try: plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode() summaries.append(plaintext) except Exception: pass if not summaries: return patterns = await _extract_proactive_patterns(summaries) if not patterns: logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id) return stored = 0 for pattern_text in patterns: try: encrypted = fernet.encrypt(pattern_text.encode()).decode() row = MemoryProactive( id=str(uuid.uuid4()), user_id=user_id, pattern_encrypted=encrypted, confidence=0.7, source="inferred", ) db.add(row) stored += 1 except Exception as exc: logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc) try: await db.commit() logger.info( "memory_maintenance: mine_proactive_patterns user=%s stored=%d", user_id, stored, ) except Exception as exc: logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc) await db.rollback() return await _decay_proactive_patterns(db, user_id, fernet) async def _extract_proactive_patterns(summaries: list[str]) -> list[str]: """Call memory-miner LLM to identify recurring behavioral/temporal patterns.""" from app.core.llm import get_agent_llm # noqa: PLC0415 llm = get_agent_llm("memory-miner", temperature=0) combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage prompt = ( "You are analyzing conversation history for a personal AI secretary. " "Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', " "'prefers bullet-point summaries', 'frequently asks about Project Acme status'). " "Return each pattern as a plain, short English sentence on its own line. " "No numbering, no bullet points, no extra text.\n\n" f"Conversation history:\n{combined}" ) try: response = await llm.ainvoke(prompt) text = response.content if hasattr(response, "content") else str(response) lines = [line.strip() for line in str(text).splitlines() if line.strip()] return lines[:5] except Exception as exc: logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc) return [] async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None: """Decay confidence of existing proactive patterns; prune below threshold.""" result = await db.execute( select(MemoryProactive).where(MemoryProactive.user_id == user_id) ) rows = result.scalars().all() now = datetime.now(timezone.utc) deleted = 0 decayed = 0 for row in rows: reference = row.created_at if reference is None: continue if reference.tzinfo is None: reference = reference.replace(tzinfo=timezone.utc) days_elapsed = (now - reference).days if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS: continue periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods) if new_confidence < _PROACTIVE_PRUNE_THRESHOLD: await db.delete(row) deleted += 1 else: row.confidence = new_confidence decayed += 1 try: await db.commit() logger.info( "memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d", user_id, decayed, deleted, ) except Exception as exc: logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc) await db.rollback() # ── Phase 7: weekly memory audit ────────────────────────────────────────────── _AUDIT_CONTRADICTIONS_FALLBACK = ( "You are auditing a personal AI assistant's memory bank. " "Each fact has an ID in brackets. " "Find pairs that directly contradict each other " "(e.g. 'prefers morning meetings' vs 'never schedules before noon'). " "For each contradiction, pick the ID to DELETE (the older or less specific one). " 'Return ONLY a valid JSON array, no markdown fences: ' '[{{"delete": "", "reason": ""}}]. ' "If no contradictions, return [].\n\n" "Facts:\n{facts}" ) _AUDIT_CANONICALIZE_FALLBACK = ( "You are auditing entity labels in a personal AI assistant's relational memory. " "These are names of people, companies, projects, or topics. " "Group labels that clearly refer to the same real-world entity " "(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). " "Return ONLY a valid JSON array, no markdown fences: " '[{{"canonical": "", "variants": ["", ""]}}]. ' "Only include groups with at least one variant. Singletons: omit.\n\n" "Labels:\n{labels}" ) async def audit_memory(db: AsyncSession, user_id: str) -> None: """Weekly audit: contradiction scan on associative facts + label canonicalization on relations. Steps: 1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM. 2. LLM flags rows to delete (direct contradictions); hard-delete them. 3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates. 4. Rewrite variant labels to their canonical form in-place. Never raises — wraps in try/except. """ try: await _audit_memory_inner(db, user_id) except Exception as exc: logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc) async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None: result = await db.execute(select(User).where(User.id == user_id)) user = result.scalar_one_or_none() if user is None or not user.encryption_key: logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id) return fernet = Fernet(user.encryption_key.encode()) await _scan_associative_contradictions(db, user_id, fernet) await _canonicalize_relation_labels(db, user_id) async def _scan_associative_contradictions( db: AsyncSession, user_id: str, fernet: Fernet, ) -> None: """Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows.""" result = await db.execute( select(MemoryAssociative) .where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()) .limit(_AUDIT_MAX_FACTS) ) rows = result.scalars().all() if len(rows) < 2: return id_to_text: dict[str, str] = {} for row in rows: try: plaintext = fernet.decrypt(row.content_encrypted.encode()).decode() id_to_text[row.id] = plaintext except Exception: pass if len(id_to_text) < 2: return id_list = list(id_to_text.keys()) numbered = "\n".join( f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list) ) template, prompt_obj = get_prompt_or_fallback( "memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK ) system_text = compile_prompt(template, prompt_obj, facts=numbered) from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415 from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415 llm = get_agent_llm("memory-auditor", temperature=0) lf = get_langfuse() messages = [ SystemMessage(content=system_text), HumanMessage(content="Audit facts for contradictions."), ] try: if lf: with lf.start_as_current_observation( as_type="generation", name="memory-audit-contradictions", model=model_for_agent("memory-auditor"), prompt=prompt_obj, input=messages, ) as gen: response = await llm.ainvoke(messages) gen.update(output=response.content, usage=extract_usage(response)) else: response = await llm.ainvoke(messages) text = response.content if hasattr(response, "content") else str(response) deletions = json.loads(text.strip()) if not isinstance(deletions, list): return except Exception as exc: logger.warning( "memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s", user_id, exc, ) return deleted = 0 for item in deletions: if not isinstance(item, dict): continue rid = item.get("delete") if not rid or rid not in id_to_text: continue result2 = await db.execute( select(MemoryAssociative).where( MemoryAssociative.id == rid, MemoryAssociative.user_id == user_id, ) ) target = result2.scalar_one_or_none() if target: await db.delete(target) deleted += 1 logger.info( "memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s", rid, user_id, item.get("reason", ""), ) if deleted: try: await db.commit() except Exception as exc: logger.warning( "memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc ) await db.rollback() logger.info( "memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted ) async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None: """Group near-duplicate entity labels in memory_relations and unify to canonical form.""" result = await db.execute( select(MemoryRelation).where(MemoryRelation.user_id == user_id) ) rows = result.scalars().all() if not rows: return all_labels: set[str] = set() for row in rows: all_labels.add(row.subject_label) all_labels.add(row.object_label) labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS] if len(labels_list) < 2: return labels_block = "\n".join(f"- {lbl}" for lbl in labels_list) template, prompt_obj = get_prompt_or_fallback( "memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK ) system_text = compile_prompt(template, prompt_obj, labels=labels_block) from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415 from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415 llm = get_agent_llm("memory-auditor", temperature=0) lf = get_langfuse() messages = [ SystemMessage(content=system_text), HumanMessage(content="Canonicalize entity labels."), ] try: if lf: with lf.start_as_current_observation( as_type="generation", name="memory-audit-canonicalize", model=model_for_agent("memory-auditor"), prompt=prompt_obj, input=messages, ) as gen: response = await llm.ainvoke(messages) gen.update(output=response.content, usage=extract_usage(response)) else: response = await llm.ainvoke(messages) text = response.content if hasattr(response, "content") else str(response) groups = json.loads(text.strip()) if not isinstance(groups, list): return except Exception as exc: logger.warning( "memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s", user_id, exc, ) return # Build variant → canonical map remap: dict[str, str] = {} for group in groups: if not isinstance(group, dict): continue canonical = group.get("canonical", "") variants = group.get("variants") or [] if not canonical: continue for v in variants: if isinstance(v, str) and v != canonical: remap[v] = canonical if not remap: return updated = 0 for row in rows: changed = False if row.subject_label in remap: row.subject_label = remap[row.subject_label] changed = True if row.object_label in remap: row.object_label = remap[row.object_label] changed = True if changed: updated += 1 if updated: try: await db.commit() logger.info( "memory_maintenance: _canonicalize_relation_labels user=%s updated=%d", user_id, updated, ) except Exception as exc: logger.warning( "memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc ) await db.rollback()