451 lines
17 KiB
Python
451 lines
17 KiB
Python
"""Mem0-style Extract/Update pipeline — Phase 2.
|
|
|
|
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
|
routines, and relations from the latest conversation turn.
|
|
|
|
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
|
|
|
Design notes
|
|
------------
|
|
- One gpt-4o-mini call per turn to extract candidates, plus at most one more call per candidate to decide its action.
|
|
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
|
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
|
treated as identifiers (safe to log per spec).
|
|
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any, Literal
|
|
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
|
from app.core.llm import get_agent_llm, model_for_agent
|
|
|
|
# Module logger. Per the zero-trust note in the module docstring, log lines here
# must not contain decrypted user content.
logger = logging.getLogger(__name__)
|
|
|
|
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
|
|
|
_EXTRACTION_FALLBACK = (
    # Role and task description.
    "You are a memory extractor for a personal AI secretary. "
    "Given the last conversation turn, the user's core memory, and recent "
    "episode summaries, identify durable facts, preferences, routines, and "
    "person/project relations worth remembering.\n\n"
    # Output schema. Literal braces are doubled because this template is
    # filled with str.format().
    "Output JSON matching this schema exactly:\n"
    '{{"candidates": [{{'
    '"type": "<fact|preference|relation|routine>", '
    '"content": "<short canonical statement>", '
    '"target_tier": "<core|associative|relational|proactive>", '
    '"subject": null, "predicate": null, "object": null, '
    '"confidence": 0.7}}]}}\n\n'
    # Extraction rules.
    "Rules:\n- Skip small talk, greetings, one-off questions.\n"
    "- Max 5 candidates per call.\n"
    "- Only extract durable information (still true next week).\n"
    "- For type=relation: subject/predicate/object required.\n"
    "- Default confidence=0.7.\n\n"
    # Context sections, filled via {last_turn}, {core_memory}, {recent_episodes}.
    "## Last turn\n"
    "{last_turn}\n\n"
    "## Core memory (current)\n"
    "{core_memory}\n\n"
    "## Recent episodes\n"
    "{recent_episodes}"
)
|
|
|
|
_DECIDE_FALLBACK = (
    # Role and task description.
    "You are a memory update decision engine. "
    "Given a new memory candidate and a list of existing memories "
    "from the same tier, decide what action to take.\n\n"
    # Output contract: a single verb.
    "Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
    # One explanation line per allowed verb.
    "- ADD: new information not in existing memories.\n"
    "- UPDATE: contradicts or supersedes an existing memory.\n"
    "- DELETE: states something is no longer true.\n"
    "- NOOP: already captured accurately.\n\n"
    # Context sections, filled via {candidate} and {existing_memories}.
    "## New candidate\n"
    "{candidate}\n\n"
    "## Existing memories (same tier, top neighbours)\n"
    "{existing_memories}"
)
|
|
|
|
|
|
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
|
|
|
class MemoryCandidate(BaseModel):
    """One memory item proposed by the extraction LLM.

    ``type`` and ``object`` shadow Python builtins as attribute names, but they
    mirror the JSON keys the extraction prompt asks the model to emit.
    """

    # Kind of information captured.
    type: Literal["fact", "preference", "relation", "routine"]
    # Short canonical statement to remember.
    content: str
    # Memory tier the candidate should be written to.
    target_tier: Literal["core", "associative", "relational", "proactive"]
    # Relation triple slots; the prompt requires them when type == "relation".
    subject: str | None = None
    predicate: str | None = None
    object: str | None = None
    # Model-reported confidence. Values outside [0.0, 1.0] fail validation.
    confidence: float = Field(0.7, ge=0.0, le=1.0)
|
|
|
|
|
|
class ExtractionResult(BaseModel):
    """Parsed output of the extraction call: a possibly empty candidate list."""

    # A fresh list per instance; a payload without "candidates" parses as [].
    candidates: list[MemoryCandidate] = Field(default_factory=list)
|
|
|
|
|
|
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
|
|
|
async def extract_candidates(
    last_turn: str,
    core_memory: dict[str, str],
    recent_episodes: list[str],
    *,
    max_candidates: int = 5,
) -> ExtractionResult:
    """Ask the extractor LLM for memory candidates found in the latest turn.

    Args:
        last_turn: The formatted last exchange (user + assistant text).
        core_memory: Current core-memory blocks, rendered as ``key: value`` lines.
        recent_episodes: Episode summaries; only the last five are sent.
        max_candidates: Upper bound on returned candidates. Defaults to 5,
            matching the "Max 5 candidates per call" rule in the prompt, which
            the model is not guaranteed to obey.

    Returns:
        An ExtractionResult. Items that fail schema validation are dropped
        one by one, so a single malformed item does not discard the batch.
        The result is empty on any failure.

    Never raises an ``Exception``. Prompt lookup, LLM construction, the call
    and parsing are all inside the guarded section. Errors are logged by
    exception type only, because messages (e.g. pydantic validation errors)
    can echo decrypted turn text.
    """
    core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
    episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
    variables = {
        "last_turn": last_turn,
        "core_memory": core_str,
        "recent_episodes": episodes_str,
    }

    try:
        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415
        from pydantic import ValidationError  # noqa: PLC0415

        template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)

        # Langfuse prompt objects are filled via compile(); the plain fallback
        # template uses {var} placeholders and str.format().
        system_text = None
        if prompt_obj is not None:
            try:
                compiled = prompt_obj.compile(**variables)
                if isinstance(compiled, list):
                    # Chat prompts compile to a list of message dicts; flatten them.
                    compiled = "\n".join(
                        m.get("content", "") for m in compiled if isinstance(m, dict)
                    )
                system_text = compiled
            except Exception as exc:
                logger.warning("memory_extraction: compile failed: %s", type(exc).__name__)
        if system_text is None:
            system_text = template.format(**variables)

        llm = get_agent_llm("memory-extractor", temperature=0)
        # JSON mode so the model replies with a JSON object.
        llm_json = llm.bind(response_format={"type": "json_object"})  # type: ignore[attr-defined]

        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Extract memory candidates as JSON."),
        ]

        lf = get_langfuse()
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-extraction",
                model=model_for_agent("memory-extractor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm_json.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm_json.ainvoke(messages)

        raw = json.loads(response.content)
        raw_candidates = raw.get("candidates", []) if isinstance(raw, dict) else None
        if not isinstance(raw_candidates, list):
            logger.warning("memory_extraction: response has no candidate list")
            return ExtractionResult(candidates=[])

        # Validate each item on its own so one bad item only costs itself.
        candidates: list[MemoryCandidate] = []
        dropped = 0
        for item in raw_candidates:
            try:
                candidates.append(MemoryCandidate.model_validate(item))
            except ValidationError:
                dropped += 1

        limit = max(max_candidates, 0)
        result = ExtractionResult(candidates=candidates[:limit])
        logger.info(
            "memory_extraction: extracted %d candidates (dropped %d invalid, %d over limit)",
            len(result.candidates),
            dropped,
            max(len(candidates) - limit, 0),
        )
        return result

    except Exception as exc:
        logger.warning(
            "memory_extraction: extract_candidates failed: %s", type(exc).__name__
        )
        return ExtractionResult(candidates=[])
|
|
|
|
|
|
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
|
|
|
async def decide_action(
    candidate: MemoryCandidate,
    existing: list[str],
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
    """Choose ADD / UPDATE / DELETE / NOOP for *candidate* against its neighbours.

    Args:
        candidate: The extracted memory item.
        existing: Plain-text memories from the same tier (top neighbours).

    Returns:
        The decided verb.
        - With no neighbours, ``"ADD"`` is returned without an LLM call
          (cost saving).
        - Otherwise the model's reply is scanned for the first verb token, so
          replies such as ``"NOOP."``, ``"**UPDATE**"`` or ``"No-op"`` are
          understood.
        - An unrecognisable reply or any error yields ``"ADD"``.

    Never raises an ``Exception``. The whole prompt/LLM path is guarded, and
    errors are logged by type only so user text cannot leak into logs.
    """
    if not existing:
        return "ADD"

    import re  # noqa: PLC0415

    variables = {
        "candidate": f"[{candidate.type}] {candidate.content}",
        "existing_memories": "\n".join(f"- {m}" for m in existing),
    }

    try:
        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

        template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)

        system_text = None
        if prompt_obj is not None:
            try:
                compiled = prompt_obj.compile(**variables)
                if isinstance(compiled, list):
                    # Chat prompts compile to a list of message dicts; flatten them.
                    compiled = "\n".join(
                        m.get("content", "") for m in compiled if isinstance(m, dict)
                    )
                system_text = compiled
            except Exception as exc:
                logger.warning(
                    "memory_extraction: decide compile failed: %s", type(exc).__name__
                )
        if system_text is None:
            system_text = template.format(**variables)

        llm = get_agent_llm("memory-extractor", temperature=0)
        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Decide action."),
        ]

        lf = get_langfuse()
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="memory-decide-action",
                model=model_for_agent("memory-extractor"),
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm.ainvoke(messages)

        reply = response.content if isinstance(response.content, str) else ""
        # Whole-word match: tolerates punctuation/markdown around the verb but
        # does not mistake e.g. "ADDRESS" for ADD.
        match = re.search(r"\b(ADD|UPDATE|DELETE|NO[-_ ]?OP)\b", reply.upper())
        if match is None:
            # Log the length only: the reply could echo the candidate's text.
            logger.warning(
                "memory_extraction: unexpected decide reply (len=%d), defaulting ADD",
                len(reply),
            )
            return "ADD"
        verb = match.group(1)
        return "NOOP" if verb.startswith("NO") else verb  # type: ignore[return-value]

    except Exception as exc:
        logger.warning("memory_extraction: decide_action failed: %s", type(exc).__name__)
        return "ADD"
|
|
|
|
|
|
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
|
|
|
async def run_extraction(
    db: AsyncSession,
    user_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Run the Mem0-style extract -> decide -> apply pipeline for one turn.

    The work is done by ``_run_extraction_inner``, which:
    - loads core memory and recent episodes,
    - asks the LLM for candidates,
    - resolves each candidate against its tier's neighbours and applies the result.

    This wrapper is the safety boundary. Any ``Exception`` is logged and
    swallowed, because the caller schedules it with ``asyncio.create_task()``
    next to the request path.
    """
    try:
        await _run_extraction_inner(
            db=db,
            user_id=user_id,
            last_user_msg=last_user_msg,
            last_assistant_msg=last_assistant_msg,
            session_id=session_id,
        )
    except Exception as err:  # deliberate best-effort boundary
        logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, err)
|
|
|
|
|
|
async def _run_extraction_inner(
    db: AsyncSession,
    user_id: str,
    last_user_msg: str,
    last_assistant_msg: str,
    session_id: str | None,
) -> None:
    """Unguarded body of ``run_extraction``; the public wrapper catches errors.

    Steps:
    - Loads the user's cipher, core memory and recent episodes through
      ``MemoryMiddleware``.
    - Extracts candidates from the turn.
    - Applies each candidate independently, inside a Langfuse span when
      Langfuse is configured.

    A failure while applying one candidate is logged and does not stop the
    rest. That log line identifies the candidate by index, type and tier only:
    its content is user data and must stay out of the logs (zero-trust).
    """
    from app.core.memory_middleware import MemoryMiddleware  # noqa: PLC0415

    middleware = MemoryMiddleware(db)
    fernet = await middleware._get_fernet(user_id)
    if fernet is None:
        logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
        return

    # 1. Load context.
    core: dict[str, str] = await middleware._load_core(user_id, fernet)
    episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
    last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"

    lf = get_langfuse()

    async def _run(trace_id: str | None) -> dict[str, Any]:
        # 2. Extract candidates.
        result = await extract_candidates(last_turn, core, episodes)
        if not result.candidates:
            logger.info("memory_extraction: no candidates user=%s", user_id)
            return {"candidates": 0, "applied": 0}

        logger.info(
            "memory_extraction: processing %d candidates user=%s trace=%s",
            len(result.candidates),
            user_id,
            trace_id or "-",
        )

        # 3. Apply each candidate independently.
        applied = 0
        actions: list[str] = []
        for index, candidate in enumerate(result.candidates):
            try:
                await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
            except Exception as exc:
                # Zero-trust: never log candidate.content (decrypted user data),
                # and log only the exception type since messages may echo it.
                logger.warning(
                    "memory_extraction: apply failed candidate=#%d type=%s tier=%s user=%s: %s",
                    index,
                    candidate.type,
                    candidate.target_tier,
                    user_id,
                    type(exc).__name__,
                )
                continue
            applied += 1
            actions.append(f"{candidate.type}:{candidate.target_tier}")

        logger.info(
            "memory_extraction: applied %d/%d candidates user=%s",
            applied,
            len(result.candidates),
            user_id,
        )
        return {"candidates": len(result.candidates), "applied": applied, "actions": actions}

    with langfuse_context(user_id=user_id, session_id=session_id):
        if not lf:
            await _run(trace_id=None)
            return

        try:
            with lf.start_as_current_observation(
                as_type="span",
                name="memory-extraction-pipeline",
                input={"last_turn_preview": last_turn[:200]},
            ) as span:
                summary = await _run(trace_id=span.id)
                span.update(output=summary)
        finally:
            # Flush after the span context has exited. The span presumably
            # is not exported until it has ended.
            try:
                lf.flush()
            except Exception as exc:
                logger.debug(
                    "memory_extraction: langfuse flush failed: %s", type(exc).__name__
                )
|
|
|
|
|
|
async def _apply_candidate(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    fernet: Any,
    candidate: MemoryCandidate,
    trace_id: str | None,
) -> None:
    """Fetch neighbours for *candidate*, decide an action, and apply it to its tier.

    Args:
        middleware: ``MemoryMiddleware`` instance used for reads and writes.
        db: Session forwarded to the relation / proactive writers.
        user_id: Owner of the memories.
        fernet: The user's cipher, forwarded to ``_store_proactive_stub``.
        candidate: The extracted memory item.
        trace_id: Langfuse span id forwarded to relation / core writes, or None.

    Raises:
        Whatever the middleware or storage helpers raise. The pipeline loop in
        ``_run_extraction_inner`` handles failures per candidate.

    Candidates that cannot be applied are skipped with an info log rather
    than dropped silently. This covers a relation without a full triple and a
    DELETE on a tier without a delete path. Log lines carry type, tier and
    action only, never the candidate text.
    """
    tier = candidate.target_tier

    neighbours: list[str] = []
    if tier == "core":
        # The first three blocks as returned by list_core_blocks. No similarity
        # ranking is applied here.
        blocks = await middleware.list_core_blocks(user_id)
        neighbours = [b["value"] for b in blocks[:3]]
    elif tier == "associative":
        neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
    elif tier == "proactive":
        neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
    # "relational": no neighbour search, so decide_action short-circuits to ADD.

    action = await decide_action(candidate, neighbours)
    logger.info(
        "memory_extraction: candidate type=%s tier=%s action=%s",
        candidate.type,
        tier,
        action,
    )

    if action == "NOOP":
        return

    if tier == "relational":
        if candidate.subject and candidate.predicate and candidate.object:
            await _upsert_relation(middleware, db, user_id, candidate, trace_id)
        else:
            logger.info(
                "memory_extraction: skipped relation candidate without full "
                "subject/predicate/object"
            )
        return

    if action in ("ADD", "UPDATE"):
        # NOTE(review): decide_action does not say WHICH neighbour an UPDATE
        # supersedes, so UPDATE uses the same write as ADD (core keys come from
        # the new text). An older entry under a different key / row presumably
        # remains; confirm whether that is intended.
        if tier == "core":
            key = _content_to_key(candidate.content)
            await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
        elif tier == "associative":
            await middleware.store_associative(user_id, candidate.content)
        elif tier == "proactive":
            await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
        return

    # action == "DELETE"
    if tier == "core":
        # NOTE(review): the key is derived from the candidate's own text, so this
        # only hits a block whose key was derived from matching text.
        await middleware.delete_core(user_id, _content_to_key(candidate.content))
    else:
        logger.info("memory_extraction: DELETE unsupported for tier=%s, ignored", tier)
|
|
|
|
|
|
def _content_to_key(content: str) -> str:
    """Derive a short snake_case core-memory key from *content*.

    The base key is built from the first 40 characters:
    - lower-cased,
    - every run of characters outside ``[a-z0-9]`` collapsed to ``_``,
    - leading/trailing underscores stripped.

    Callers recompute keys with this function for UPDATE/DELETE, so this
    ASCII mapping is kept stable.

    Text with no ASCII letters or digits (e.g. Cyrillic, CJK or emoji-only)
    would otherwise collapse to one shared key and overwrite other entries.
    It gets ``"memory_<12 hex chars>"``, a stable SHA-256 prefix of the full
    text. Blank content maps to ``"memory"``.
    """
    import hashlib  # noqa: PLC0415
    import re  # noqa: PLC0415

    slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
    if slug:
        return slug
    if not content.strip():
        return "memory"
    digest = hashlib.sha256(content.encode("utf-8", "surrogatepass")).hexdigest()[:12]
    return f"memory_{digest}"
|
|
|
|
|
|
async def _upsert_relation(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    candidate: MemoryCandidate,
    trace_id: str | None,
) -> None:
    """Write *candidate* as a subject-predicate-object row via ``upsert_relation``.

    Missing triple parts are replaced by placeholders: ``"unknown"`` for
    subject/object and ``"related_to"`` for the predicate. Both entity types
    are always sent as ``"unknown"``. ``db`` and ``trace_id`` are accepted but
    not used in this body.
    """
    subj = candidate.subject or "unknown"
    pred = candidate.predicate or "related_to"
    obj = candidate.object or "unknown"

    await middleware.upsert_relation(
        user_id=user_id,
        subject=subj,
        subject_type="unknown",
        predicate=pred,
        object_=obj,
        object_type="unknown",
        confidence=candidate.confidence,
    )

    # Relation labels are treated as identifiers (module docstring), so logging
    # them is permitted. The raw candidate values are logged, not the placeholders.
    logger.info(
        "memory_extraction: upserted relation subject=%s predicate=%s object=%s",
        candidate.subject,
        candidate.predicate,
        candidate.object,
    )
|
|
|
|
|
|
async def _store_proactive_stub(
    middleware: Any,
    db: AsyncSession,
    user_id: str,
    candidate: MemoryCandidate,
    fernet: Any,
) -> None:
    """Persist *candidate* as an encrypted ``MemoryProactive`` row and commit.

    The pattern text is encrypted with the user's *fernet* via ``_encrypt``
    before it is stored. ``middleware`` is accepted but not used here.

    Raises:
        The commit error, re-raised after rolling the session back. The
        caller's per-candidate handler then records the failure instead of
        counting the candidate as applied.
    """
    import uuid  # noqa: PLC0415

    from app.models import MemoryProactive  # noqa: PLC0415
    from app.core.memory_middleware import _encrypt  # noqa: PLC0415

    row = MemoryProactive(
        id=str(uuid.uuid4()),
        user_id=user_id,
        pattern_encrypted=_encrypt(fernet, candidate.content),
        confidence=candidate.confidence,
        source="inferred",
    )
    db.add(row)
    try:
        await db.commit()
    except Exception as exc:
        logger.warning("memory_extraction: store proactive failed: %s", exc)
        await db.rollback()
        raise
    logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|