"""Mem0-style Extract/Update pipeline — Phase 2. Runs after every ``store_episode`` call to distil durable facts, preferences, routines, and relations from the latest conversation turn. Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)`` Design notes ------------ - Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate. - Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving). - Zero-trust: never logs decrypted user content; relation subject/object labels are treated as identifiers (safe to log per spec). - Must not raise into the request path — caller wraps in asyncio.create_task(). """ from __future__ import annotations import json import logging from typing import Any, Literal from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context from app.core.llm import get_agent_llm, model_for_agent logger = logging.getLogger(__name__) # ── Fallback prompts (used when Langfuse unavailable) ───────────────────────── _EXTRACTION_FALLBACK = ( "You are a memory extractor for a personal AI secretary. Given the last conversation " "turn, the user's core memory, and recent episode summaries, identify durable facts, " "preferences, routines, and person/project relations worth remembering.\n\n" "Output JSON matching this schema exactly:\n" '{{"candidates": [{{"type": "", ' '"content": "", ' '"target_tier": "", ' '"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n' "Rules:\n" "- Skip small talk, greetings, one-off questions.\n" "- Max 5 candidates per call.\n" "- Only extract durable information (still true next week).\n" "- For type=relation: subject/predicate/object required.\n" "- Default confidence=0.7.\n\n" "## Last turn\n{last_turn}\n\n" "## Core memory (current)\n{core_memory}\n\n" "## Recent episodes\n{recent_episodes}" ) _DECIDE_FALLBACK = ( "You are a memory update decision engine. Given a new memory candidate and a list of " "existing memories from the same tier, decide what action to take.\n\n" "Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n" "- ADD: new information not in existing memories.\n" "- UPDATE: contradicts or supersedes an existing memory.\n" "- DELETE: states something is no longer true.\n" "- NOOP: already captured accurately.\n\n" "## New candidate\n{candidate}\n\n" "## Existing memories (same tier, top neighbours)\n{existing_memories}" ) # ── Pydantic schemas ─────────────────────────────────────────────────────────── class MemoryCandidate(BaseModel): type: Literal["fact", "preference", "relation", "routine"] content: str target_tier: Literal["core", "associative", "relational", "proactive"] subject: str | None = None predicate: str | None = None object: str | None = None confidence: float = Field(default=0.7, ge=0.0, le=1.0) class ExtractionResult(BaseModel): candidates: list[MemoryCandidate] = Field(default_factory=list) # ── Task 2.1 — Extract candidates ───────────────────────────────────────────── async def extract_candidates( last_turn: str, core_memory: dict[str, str], recent_episodes: list[str], ) -> ExtractionResult: """Call gpt-4o-mini to extract memory candidates from the latest turn. Returns an ExtractionResult (may be empty on failure — never raises). """ core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)" episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)" template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK) # Compile with Langfuse variable syntax ({{var}}) or fallback {var} if prompt_obj is not None: try: system_text = prompt_obj.compile( last_turn=last_turn, core_memory=core_str, recent_episodes=episodes_str, ) if isinstance(system_text, list): system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict)) except Exception as exc: logger.warning("memory_extraction: compile failed: %s", exc) system_text = template.format( last_turn=last_turn, core_memory=core_str, recent_episodes=episodes_str, ) else: system_text = template.format( last_turn=last_turn, core_memory=core_str, recent_episodes=episodes_str, ) llm = get_agent_llm("memory-extractor", temperature=0) # Bind JSON mode so the model always returns parseable output. llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined] lf = get_langfuse() try: from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415 messages = [ SystemMessage(content=system_text), HumanMessage(content="Extract memory candidates as JSON."), ] if lf: with lf.start_as_current_observation( as_type="generation", name="memory-extraction", model=model_for_agent("memory-extractor"), prompt=prompt_obj, input=messages, ) as gen: response = await llm_json.ainvoke(messages) gen.update(output=response.content, usage=extract_usage(response)) else: response = await llm_json.ainvoke(messages) raw = json.loads(response.content) result = ExtractionResult.model_validate(raw) logger.info("memory_extraction: extracted %d candidates", len(result.candidates)) return result except Exception as exc: logger.warning("memory_extraction: extract_candidates failed: %s", exc) return ExtractionResult(candidates=[]) # ── Task 2.2 — Decide action ────────────────────────────────────────────────── async def decide_action( candidate: MemoryCandidate, existing: list[str], ) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]: """Decide what to do with a candidate given existing memories in the same tier. Short-circuits to ADD without an LLM call when existing is empty (cost saving). Never raises. """ if not existing: return "ADD" candidate_str = f"[{candidate.type}] {candidate.content}" existing_str = "\n".join(f"- {m}" for m in existing) template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK) if prompt_obj is not None: try: system_text = prompt_obj.compile( candidate=candidate_str, existing_memories=existing_str, ) if isinstance(system_text, list): system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict)) except Exception as exc: logger.warning("memory_extraction: decide compile failed: %s", exc) system_text = template.format(candidate=candidate_str, existing_memories=existing_str) else: system_text = template.format(candidate=candidate_str, existing_memories=existing_str) llm = get_agent_llm("memory-extractor", temperature=0) lf = get_langfuse() try: from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415 messages = [ SystemMessage(content=system_text), HumanMessage(content="Decide action."), ] if lf: with lf.start_as_current_observation( as_type="generation", name="memory-decide-action", model=model_for_agent("memory-extractor"), prompt=prompt_obj, input=messages, ) as gen: response = await llm.ainvoke(messages) gen.update(output=response.content, usage=extract_usage(response)) else: response = await llm.ainvoke(messages) verb = response.content.strip().upper() if verb in ("ADD", "UPDATE", "DELETE", "NOOP"): return verb # type: ignore[return-value] logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb) return "ADD" except Exception as exc: logger.warning("memory_extraction: decide_action failed: %s", exc) return "ADD" # ── Task 2.3 — Pipeline orchestrator ────────────────────────────────────────── async def run_extraction( db: AsyncSession, user_id: str, last_user_msg: str, last_assistant_msg: str, session_id: str | None, ) -> None: """Full Mem0-style extract/update pipeline for one conversation turn. Steps: 1. Load core memory + last 5 episodes. 2. extract_candidates() → up to 5 MemoryCandidate objects. 3. For each candidate: find top-3 neighbours → decide_action() → apply. 4. Trace via Langfuse. Never raises — wraps everything in try/except. """ try: await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id) except Exception as exc: logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc) async def _run_extraction_inner( db: AsyncSession, user_id: str, last_user_msg: str, last_assistant_msg: str, session_id: str | None, ) -> None: from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415 middleware = MemoryMiddleware(db) fernet = await middleware._get_fernet(user_id) if fernet is None: logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id) return # 1. Load context core: dict[str, str] = await middleware._load_core(user_id, fernet) episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id) last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}" lf = get_langfuse() async def _run(trace_id: str | None) -> dict[str, Any]: # 2. Extract candidates result = await extract_candidates(last_turn, core, episodes) if not result.candidates: logger.info("memory_extraction: no candidates user=%s", user_id) return {"candidates": 0, "applied": 0} logger.info( "memory_extraction: processing %d candidates user=%s trace=%s", len(result.candidates), user_id, trace_id or "-", ) # 3. Apply each candidate applied = 0 actions: list[str] = [] for candidate in result.candidates: try: await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id) applied += 1 actions.append(f"{candidate.type}:{candidate.target_tier}") except Exception as exc: logger.warning( "memory_extraction: apply failed candidate=%r user=%s: %s", candidate.content[:80], user_id, exc, ) logger.info( "memory_extraction: applied %d/%d candidates user=%s", applied, len(result.candidates), user_id, ) return {"candidates": len(result.candidates), "applied": applied, "actions": actions} with langfuse_context(user_id=user_id, session_id=session_id): if lf: with lf.start_as_current_observation( as_type="span", name="memory-extraction-pipeline", input={"last_turn_preview": last_turn[:200]}, ) as span: summary = await _run(trace_id=span.id) span.update(output=summary) try: lf.flush() except Exception: pass else: await _run(trace_id=None) async def _apply_candidate( middleware: Any, db: AsyncSession, user_id: str, fernet: Any, candidate: MemoryCandidate, trace_id: str | None, ) -> None: """Fetch neighbours, decide action, apply to the appropriate tier.""" neighbours: list[str] = [] if candidate.target_tier == "core": # For core tier: neighbours are existing core block values for similar keys. blocks = await middleware.list_core_blocks(user_id) neighbours = [b["value"] for b in blocks[:3]] elif candidate.target_tier == "associative": neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3) elif candidate.target_tier == "relational": # Relation candidates handled specially — passed to upsert_relation directly. # Neighbours: search by subject label if available. neighbours = [] elif candidate.target_tier == "proactive": neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3) action = await decide_action(candidate, neighbours) logger.info( "memory_extraction: candidate type=%s tier=%s action=%s", candidate.type, candidate.target_tier, action, ) if action == "NOOP": return if candidate.target_tier == "relational": # Always upsert relations — decide_action skipped (no neighbour search). if candidate.subject and candidate.predicate and candidate.object: await _upsert_relation_stub( middleware, db, user_id, candidate, trace_id ) return if action in ("ADD", "UPDATE"): if candidate.target_tier == "core": # Derive a short key from the content (first 40 chars, snake_cased). key = _content_to_key(candidate.content) await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id) elif candidate.target_tier == "associative": await middleware.store_associative(user_id, candidate.content) elif candidate.target_tier == "proactive": await _store_proactive_stub(middleware, db, user_id, candidate, fernet) elif action == "DELETE": if candidate.target_tier == "core": key = _content_to_key(candidate.content) await middleware.delete_core(user_id, key) def _content_to_key(content: str) -> str: """Derive a short snake_case key from a content string (first 40 chars).""" import re # noqa: PLC0415 slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_") return slug or "memory" async def _upsert_relation_stub( middleware: Any, db: AsyncSession, user_id: str, candidate: MemoryCandidate, trace_id: str | None, ) -> None: """Stub: upsert_relation will be fully wired in Phase 3. Called here so Phase 2 extraction pipeline already routes relation candidates correctly. Phase 3 replaces this with MemoryMiddleware.upsert_relation(). """ if hasattr(middleware, "upsert_relation"): await middleware.upsert_relation( user_id=user_id, subject=candidate.subject, subject_type="unknown", predicate=candidate.predicate, object_=candidate.object, object_type="unknown", confidence=candidate.confidence, ) else: logger.info( "memory_extraction: relation stub (Phase 3 not yet wired) subject=%s predicate=%s object=%s", candidate.subject, candidate.predicate, candidate.object, ) async def _store_proactive_stub( middleware: Any, db: AsyncSession, user_id: str, candidate: MemoryCandidate, fernet: Any, ) -> None: """Store a proactive pattern row directly (MemoryProactive model).""" import uuid # noqa: PLC0415 from app.models import MemoryProactive # noqa: PLC0415 from app.core.memory_middleware import _encrypt # noqa: PLC0415 encrypted = _encrypt(fernet, candidate.content) row = MemoryProactive( id=str(uuid.uuid4()), user_id=user_id, pattern_encrypted=encrypted, confidence=candidate.confidence, source="inferred", ) db.add(row) try: await db.commit() logger.info("memory_extraction: stored proactive pattern user=%s", user_id) except Exception as exc: logger.warning("memory_extraction: store proactive failed: %s", exc) await db.rollback()